chenpangpang / transformers

Unverified commit 4e4403c9, authored Mar 19, 2020 by Sam Shleifer, committed by GitHub on Mar 19, 2020
Parent: c44a17db

[BART] torch 1.0 compatibility (#3322)

* config.activation_function
Showing 3 changed files with 25 additions and 41 deletions (+25 -41):

  src/transformers/activations.py          +1   -5
  src/transformers/configuration_bart.py   +2   -0
  src/transformers/modeling_bart.py        +22  -36
src/transformers/activations.py

```diff
@@ -44,8 +44,4 @@ def get_activation(activation_string):
     if activation_string in ACT2FN:
         return ACT2FN[activation_string]
     else:
-        raise KeyError(
-            "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
-                activation_string, list(ACT2FN.keys())
-            )
-        )
+        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
```
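For orientation, `get_activation` is a thin registry lookup: the `ACT2FN` dict defined in the same module maps activation names to callables, and the BART changes below depend on it. A minimal, self-contained sketch of the pattern; the exact keys registered in this version of `ACT2FN` are an assumption:

```python
import torch

# Hypothetical stand-in for the ACT2FN registry defined in activations.py.
ACT2FN = {"relu": torch.nn.functional.relu, "tanh": torch.tanh}


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))


print(get_activation("relu"))        # <function relu ...>
try:
    get_activation("does_not_exist")
except KeyError as e:
    print(e)                         # names the available keys, per the updated message
```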
src/transformers/configuration_bart.py

```diff
@@ -39,6 +39,7 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
+        activation_function="gelu",
         vocab_size=50265,
         bos_token_id=0,
         pad_token_id=1,
@@ -89,6 +90,7 @@ class BartConfig(PretrainedConfig):
         self.decoder_attention_heads = decoder_attention_heads
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
+        self.activation_function = activation_function

         # 3 Types of Dropout
         self.attention_dropout = attention_dropout
```
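A quick sketch of the new field from the user's side; it assumes an installed transformers version that includes this commit, and "relu" is used purely as an example of another registered activation:

```python
from transformers import BartConfig

# "gelu" remains the default, so existing configs and checkpoints are unaffected.
config = BartConfig()
print(config.activation_function)                    # "gelu"

# The activation can now be chosen per model instead of being hard-coded to F.gelu.
relu_config = BartConfig(activation_function="relu")
print(relu_config.activation_function)               # "relu"
print(relu_config.to_dict()["activation_function"])  # serialized with the rest of the config
```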
src/transformers/modeling_bart.py

```diff
@@ -21,6 +21,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

+from .activations import ACT2FN
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
```
```diff
@@ -196,7 +197,7 @@ class EncoderLayer(nn.Module):
         )
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
-        self.activation_fn = F.gelu
+        self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
```
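This is the core of the torch 1.0 fix: `F.gelu` is not available in torch 1.0 (it was added to `torch.nn.functional` in a later release), so the layer now looks its activation up in `ACT2FN`, which can point at a pure-Python GELU and also honors the new `config.activation_function` field. A minimal sketch of the wiring, with a toy module standing in for `EncoderLayer`:

```python
import math

import torch
from torch import nn


def gelu(x):
    # Closed-form GELU that avoids F.gelu, which is unavailable on torch 1.0.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# Hypothetical registry mirroring activations.ACT2FN.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu}


class TinyFFN(nn.Module):
    """Toy feed-forward block wired like the patched EncoderLayer."""

    def __init__(self, dim, ffn_dim, activation_function="gelu"):
        super().__init__()
        self.activation_fn = ACT2FN[activation_function]  # was: self.activation_fn = F.gelu
        self.fc1 = nn.Linear(dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, dim)

    def forward(self, x):
        return self.fc2(self.activation_fn(self.fc1(x)))


out = TinyFFN(8, 32)(torch.randn(2, 5, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```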
```diff
@@ -278,8 +279,8 @@ class BartEncoder(nn.Module):
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
-            assert attention_mask.max() <= 0
+            attention_mask = attention_mask.eq(0)
+
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
```
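The encoder's mask handling switches from an additive float mask (zeros at real tokens, a large negative value at padding) to a boolean padding mask: `attention_mask.eq(0)` is truthy exactly where the 1/0 mask marks padding, which is the form `masked_fill` in `SelfAttention` consumes further down. A small illustrative comparison; the names are local to this sketch:

```python
import torch

# 1 = real token, 0 = padding, as produced by the tokenizer.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Old style: additive mask that gets added to the attention logits.
LARGE_NEGATIVE = -1e8
additive_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE   # ~0 at tokens, -1e8 at padding

# New style: boolean padding mask consumed via masked_fill.
padding_mask = attention_mask.eq(0)                              # [[False, False, False, True, True]]

scores = torch.zeros(1, 5)                                       # fake attention logits
print(scores + additive_mask)                                    # old: padding columns pushed to -1e8
print(scores.masked_fill(padding_mask, float("-inf")))           # new: padding columns set to -inf
```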
```diff
@@ -318,7 +319,7 @@ class DecoderLayer(nn.Module):
             embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
-        self.activation_fn = F.gelu
+        self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
```
```diff
@@ -334,13 +335,7 @@ class DecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(
-        self,
-        x,
-        encoder_hidden_states,
-        encoder_attn_mask=None,
-        layer_state=None,
-        attention_mask=None,
-        need_attn_weights=False,
+        self, x, encoder_hidden_states, encoder_attn_mask=None, layer_state=None, attention_mask=None,
     ):
         residual = x
```
```diff
@@ -437,9 +432,7 @@ class BartDecoder(nn.Module):
         # check attention mask and invert
         if encoder_padding_mask is not None:
             assert encoder_padding_mask.dim() == 2
-
-            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
-            assert encoder_padding_mask.max() <= 0
+            encoder_padding_mask = encoder_padding_mask.eq(0)

         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=generation_mode)
```
```diff
@@ -469,12 +462,7 @@ class BartDecoder(nn.Module):
             layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None

             x, layer_self_attn, layer_past = decoder_layer(
-                x,
-                encoder_hidden_states,
-                encoder_padding_mask,
-                layer_state=layer_state,
-                attention_mask=combined_mask,
-                need_attn_weights=self.output_attentions,
+                x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
             )

             if self.output_past:
```
```diff
@@ -598,7 +586,7 @@ class SelfAttention(nn.Module):
         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
+            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
```
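Dropping `.to(torch.bool)` is itself part of the torch 1.0 story: the `torch.bool` dtype did not exist in 1.0 (it arrived in a later release), and since the padding mask is now built with `.eq(0)` it already carries a dtype that version's `masked_fill` accepts (uint8 on old torch, bool on newer torch). A hedged sketch of the broadcasted fill:

```python
import torch

# (bsz, src_len) padding mask; dtype is uint8 on old torch, bool on newer torch.
key_padding_mask = torch.tensor([[1, 1, 0]]).eq(0)

# Reshape to (bsz, 1, 1, src_len) so it broadcasts over heads and query positions.
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)

attn_weights = torch.zeros(1, 4, 2, 3)  # (bsz, num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
print(attn_weights[0, 0])  # the last (padding) column is -inf for every query position
```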
```diff
@@ -648,22 +636,20 @@ class SelfAttention(nn.Module):
         static_kv: bool,
     ) -> Optional[Tensor]:
         # saved key padding masks have shape (bsz, seq_len)
-        if prev_key_padding_mask is not None and static_kv:
-            new_key_padding_mask = prev_key_padding_mask
-        elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
-        # During incremental decoding, as the padding token enters and
-        # leaves the frame, there will be a time when prev or current is None
-        elif prev_key_padding_mask is not None:
-            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
-            if prev_key_padding_mask.is_cuda:
-                filler = filler.to(prev_key_padding_mask.device)
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+        if prev_key_padding_mask is not None:
+            if static_kv:
+                new_key_padding_mask = prev_key_padding_mask
+            else:
+                new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
         elif key_padding_mask is not None:
-            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
-            if key_padding_mask.is_cuda:
-                filler = filler.cuda()
-            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+            filler = torch.zeros(
+                batch_size,
+                src_len - key_padding_mask.size(1),
+                dtype=key_padding_mask.dtype,
+                device=key_padding_mask.device,
+            )
+            new_key_padding_mask = torch.cat([filler, key_padding_mask], dim=1)
         else:
             new_key_padding_mask = prev_key_padding_mask
         return new_key_padding_mask
```
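The rewritten `_cat_prev_key_padding_mask` also replaces the `.is_cuda` / `.cuda()` / `.float()` shuffling with a single `torch.zeros(..., dtype=..., device=...)` call, so the filler is allocated on the right device with the mask's own dtype and the concatenated result stays a boolean mask. A hedged before/after sketch of just that allocation:

```python
import torch

key_padding_mask = torch.tensor([[0, 0, 1]]).eq(1)  # True at padding positions
batch_size, src_len = 1, 5

# Old pattern: allocate a float32 CPU tensor, then move and cast after the fact.
filler_old = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda:
    filler_old = filler_old.cuda()
mask_old = torch.cat([filler_old.float(), key_padding_mask.float()], dim=1)

# New pattern: allocate with the mask's dtype and device in one call.
filler_new = torch.zeros(
    batch_size,
    src_len - key_padding_mask.size(1),
    dtype=key_padding_mask.dtype,
    device=key_padding_mask.device,
)
mask_new = torch.cat([filler_new, key_padding_mask], dim=1)

print(mask_old.dtype, mask_new.dtype)  # torch.float32 vs. torch.bool (uint8 on older torch)
```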