chenpangpang / transformers
Commit 2bd79e23
"docs/vscode:/vscode.git/clone" did not exist on "5c82bf6831b49e1e6029d09488081d5d98a272e9"
Unverified commit 2bd79e23, authored Mar 13, 2020 by Sam Shleifer; committed by GitHub on Mar 13, 2020.
[BART] FP16 testing fixes (#3266)
parent 8320feec
Changes: 2 changed files with 16 additions and 4 deletions (+16, -4).

src/transformers/modeling_bart.py (+8, -2)
tests/test_modeling_bart.py (+8, -2)
src/transformers/modeling_bart.py
@@ -82,7 +82,7 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
+    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
     """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
@@ -101,6 +101,8 @@ def _prepare_bart_decoder_inputs(
         new_shape = (bsz, tgt_len, tgt_len)
         # make it broadcastable so can just be added to the attention coefficients
         decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
+        if mask_dtype is not None:
+            decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
     assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
     return decoder_input_ids, decoder_attn_mask
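Why the new mask_dtype cast matters: _combine_masks builds the mask in float32 around LARGE_NEGATIVE = -1e8, while a model converted with .half() computes float16 attention scores, so the mask's dtype has to be aligned before the two are added. A minimal sketch of the failure mode and the fix (editor's illustration, not part of the commit; assumes only PyTorch):

import torch

LARGE_NEGATIVE = -1e8

# Attention scores as they would look inside a .half() model.
scores = torch.zeros(1, 4, dtype=torch.float16)

# A mask built without an explicit dtype defaults to float32.
mask = torch.tensor([[0.0, LARGE_NEGATIVE, 0.0, 0.0]])

# Casting the mask to the scores' dtype (what mask_dtype enables) keeps the
# addition in half precision. Note that -1e8 saturates to -inf in float16
# (max finite value ~65504), which still zeroes the masked position after softmax.
masked = scores + mask.to(scores.dtype)
print(masked.dtype)                           # torch.float16
# softmax is done in float32 here: CPU half-precision kernels are limited,
# which is also why the tests below skip on CPU.
print(torch.softmax(masked.float(), dim=-1))  # ~[0.33, 0.00, 0.33, 0.33]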
@@ -838,7 +840,11 @@ class BartModel(PretrainedBartModel):
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
-                self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
+                self.config,
+                input_ids,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attn_mask=decoder_attention_mask,
+                mask_dtype=self.shared.weight.dtype,
             )
         assert decoder_input_ids is not None
         if encoder_outputs is None:
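The call site derives the dtype from an existing parameter rather than hard-coding it: self.shared is BART's shared embedding, and after model.half() its weights, like every other parameter, are float16. A small sketch of that trick (editor's illustration, assumes only PyTorch):

import torch

emb = torch.nn.Embedding(10, 4)
print(emb.weight.dtype)   # torch.float32

# half() converts parameters in place, so the same expression now reports
# the model's runtime precision -- exactly what mask_dtype needs to match.
emb.half()
print(emb.weight.dtype)   # torch.float16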
tests/test_modeling_bart.py
@@ -314,10 +314,16 @@ class BartHeadTests(unittest.TestCase):
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_generate_fp16(self):
         config, input_ids, batch_size = self._get_config_and_data(output_past=True)
-        input_ids = input_ids
         attention_mask = input_ids.ne(1).to(torch_device)
         model = BartForConditionalGeneration(config).eval().to(torch_device).half()
-        model.generate(input_ids, attention_mask=attention_mask)
+        model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
+
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_base_model_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
+        attention_mask = input_ids.ne(1).to(torch_device)
+        lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
+        lm_model.generate(input_ids, attention_mask=attention_mask)
+        lm_model(input_ids, attention_mask=attention_mask)

     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
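For context, the scenario these tests guard is half-precision generation on a GPU. A hedged usage sketch with a present-day transformers API and the public facebook/bart-large checkpoint (neither the checkpoint name nor the tokenizer call style is part of this commit):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = (
    BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    .eval()
    .to("cuda")
    .half()  # the same conversion exercised by test_generate_fp16
)

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="pt").to("cuda")
with torch.no_grad():
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        do_sample=False,
        early_stopping=True,
    )
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))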