chenpangpang / transformers / Commits / 5ea8ba67

Unverified commit 5ea8ba67, authored Mar 15, 2020 by Sam Shleifer, committed by GitHub on Mar 15, 2020.
[BART] Remove unused kwargs (#3279)

* Remove unused kwargs
* Don't call forward in tests
parent 3814e167

Showing 3 changed files with 14 additions and 29 deletions (+14 -29).
examples/summarization/bertabs/modeling_bertabs.py   +1  -1
src/transformers/modeling_bart.py                    +4  -19
tests/test_modeling_bart.py                          +9  -9
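The second bullet of the commit message ("don't call forward in tests") follows standard PyTorch practice: invoking a module as model(...) goes through nn.Module.__call__, which runs any registered hooks before and after dispatching to forward, whereas model.forward(...) bypasses them. A minimal self-contained sketch of the difference (the toy TinyModel below is illustrative only, not part of this repository):

import torch
from torch import nn

class TinyModel(nn.Module):
    """Toy module used only to show the __call__ vs .forward() difference."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
calls = []
# Forward hooks are triggered by nn.Module.__call__, not by forward() itself.
model.register_forward_hook(lambda module, inputs, output: calls.append("hook ran"))

x = torch.randn(3, 4)
model(x)          # hook fires
model.forward(x)  # hook is skipped
print(calls)      # ['hook ran'] -- only the direct call triggered the hook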
examples/summarization/bertabs/modeling_bertabs.py

@@ -844,7 +844,7 @@ class Translator(object):
             dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

             # Generator forward.
-            log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
+            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
             vocab_size = log_probs.size(-1)

             if step < min_length:
src/transformers/modeling_bart.py

@@ -223,9 +223,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
         x, attn_weights = self.self_attn(
-            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
+            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)

@@ -378,7 +376,7 @@ class DecoderLayer(nn.Module):
             layer_state = {}
         # next line mutates layer state
         x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
+            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x

@@ -393,7 +391,6 @@ class DecoderLayer(nn.Module):
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
             static_kv=True,
-            need_weights=False,  # not returning it so why compute it
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x

@@ -548,16 +545,12 @@ class SelfAttention(nn.Module):
         self,
         embed_dim,
         num_heads,
-        kdim=None,
-        vdim=None,
         dropout=0.0,
         bias=True,
         encoder_decoder_attention=False,  # otherwise self_attention
     ):
         super().__init__()
         self.embed_dim = embed_dim
-        self.kdim = kdim if kdim is not None else embed_dim
-        self.vdim = vdim if vdim is not None else embed_dim
         self.num_heads = num_heads
         self.dropout = dropout

@@ -566,13 +559,8 @@ class SelfAttention(nn.Module):
         self.scaling = self.head_dim ** -0.5
         self.encoder_decoder_attention = encoder_decoder_attention
-        qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim  # True for all BART
-        assert self.encoder_decoder_attention or qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
-        )
-        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

@@ -587,7 +575,6 @@ class SelfAttention(nn.Module):
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        need_weights: bool = False,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:

@@ -598,8 +585,6 @@ class SelfAttention(nn.Module):
             key_padding_mask (ByteTensor, optional): mask to exclude
                 keys that are pads, of shape `(batch, src_len)`, where
                 padding elements are indicated by 1s.
-            need_weights (bool, optional): return the attention weights,
-                averaged over heads (default: False).
             attn_mask (ByteTensor, optional): typically used to
                 implement causal attention, where the mask prevents the
                 attention from looking forward in time (default: None).
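The @@ -548 and @@ -566 hunks above drop the kdim/vdim arguments and the qkv_same_dim assert because BART always projects queries, keys and values from the same embed_dim. Below is a minimal standalone sketch of the resulting projection setup; the class name SimplifiedSelfAttention and the single-head forward are illustrative assumptions, and the real SelfAttention in modeling_bart.py additionally handles multi-head reshaping, masking and the decoder cache:

import torch
from torch import nn

class SimplifiedSelfAttention(nn.Module):
    """Sketch of the projection layout after removing kdim/vdim (illustrative only)."""

    def __init__(self, embed_dim, num_heads, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        # All four projections now map embed_dim -> embed_dim, so the old
        # kdim/vdim arguments and the qkv_same_dim assertion are dead code.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, query, key, value):
        # Single-head scaled dot-product attention over (batch, seq_len, embed_dim)
        # inputs, just to show the projections in use.
        q = self.q_proj(query) * self.scaling
        k = self.k_proj(key)
        v = self.v_proj(value)
        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
        return self.out_proj(attn @ v)

For example, SimplifiedSelfAttention(embed_dim=16, num_heads=4)(x, x, x) returns a tensor with the same shape as x.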
tests/test_modeling_bart.py

@@ -141,13 +141,13 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         _check_var(model.encoder.layers[0].fc1)
         _check_var(model.encoder.embed_positions)
-        decoder_features_with_created_mask = model.forward(**inputs_dict)[0]
-        decoder_features_with_passed_mask = model.forward(
+        decoder_features_with_created_mask = model(**inputs_dict)[0]
+        decoder_features_with_passed_mask = model(
             decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
         )[0]
         _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
         useless_mask = torch.zeros_like(decoder_attn_mask)
-        decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0]
+        decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
         self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
         self.assertEqual(
             decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)

@@ -156,7 +156,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
         # Test different encoder attention masks
-        decoder_features_with_long_encoder_mask = model.forward(
+        decoder_features_with_long_encoder_mask = model(
             inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
         )[0]
         _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)

@@ -237,7 +237,7 @@ class BartHeadTests(unittest.TestCase):
         decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
         lm_model = BartForConditionalGeneration(config)
         lm_model.to(torch_device)
-        loss, logits, enc_features = lm_model.forward(
+        loss, logits, enc_features = lm_model(
             input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
         )
         expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)

@@ -259,7 +259,7 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).to(torch_device)
         context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
         summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
-        loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
+        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)

@@ -388,7 +388,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
         with torch.no_grad():
-            output = model.forward(**inputs_dict)[0]
+            output = model(**inputs_dict)[0]
         expected_shape = torch.Size((1, 11, 1024))
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(

@@ -408,7 +408,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
         # Test that model hasn't changed
         with torch.no_grad():
-            batched_logits, features = model.forward(**inputs_dict)
+            batched_logits, features = model(**inputs_dict)
         expected_shape = torch.Size((2, 3))
         self.assertEqual(batched_logits.shape, expected_shape)
         expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)

@@ -419,7 +419,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
         with torch.no_grad():
-            logits2 = model.forward(**inputs_dict)[0]
+            logits2 = model(**inputs_dict)[0]
         _assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
         _assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)