Unverified commit 2bd79e23 in chenpangpang/transformers, authored Mar 13, 2020 by Sam Shleifer, committed by GitHub on Mar 13, 2020
[BART] FP16 testing fixes (#3266)
Parent: 8320feec
Showing 2 changed files with 16 additions and 4 deletions (+16 -4):

  src/transformers/modeling_bart.py   +8 -2
  tests/test_modeling_bart.py         +8 -2
src/transformers/modeling_bart.py:
@@ -82,7 +82,7 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
+    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
     """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
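
The new mask_dtype keyword is the entire interface change; the remaining hunks just thread it through. A minimal sketch of calling the helper directly, assuming a transformers checkout from this era (where the private helper lives in transformers.modeling_bart) and a made-up tiny config:

import torch
from transformers import BartConfig
from transformers.modeling_bart import _prepare_bart_decoder_inputs

# Hypothetical tiny config so the helper runs instantly; pad_token_id defaults to 1.
config = BartConfig(
    vocab_size=99, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
input_ids = torch.tensor([[4, 5, 6, 2, 1]])  # ends in pad so a padding mask is built

# mask_dtype is the new argument: request the combined mask in half precision.
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
    config, input_ids, mask_dtype=torch.float16
)
print(decoder_attn_mask.dtype)  # torch.float16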

@@ -101,6 +101,8 @@ def _prepare_bart_decoder_inputs(
         new_shape = (bsz, tgt_len, tgt_len)
         # make it broadcastable so can just be added to the attention coefficients
         decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
+        if mask_dtype is not None:
+            decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
     assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
     return decoder_input_ids, decoder_attn_mask
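
Why the cast matters: the additive mask is built in float32 (note LARGE_NEGATIVE = -1e8 in the hunk header above), while a model converted with .half() produces float16 attention scores. A standalone sketch of the mismatch; the promotion-versus-error behavior noted in the comments is an assumption about PyTorch versions, not something the commit states:

import torch

scores = torch.zeros(2, 1, 4, 4, dtype=torch.float16)  # fp16 attention scores
mask = torch.full((2, 1, 4, 4), -1e8)                   # float32 additive mask

print((scores + mask).dtype)  # float32: the fp16 path is silently upcast
                              # (PyTorch versions around this commit could
                              # raise a dtype mismatch here instead)

# The commit's fix, in miniature: cast the mask to the model dtype up front.
# Note that -1e8 overflows to -inf in float16, which softmax still treats as
# "fully masked", so the cast stays safe for this use.
mask = mask.to(scores.dtype)
print((scores + mask).dtype)  # torch.float16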

@@ -838,7 +840,11 @@ class BartModel(PretrainedBartModel):
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
-                self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
+                self.config,
+                input_ids,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attn_mask=decoder_attention_mask,
+                mask_dtype=self.shared.weight.dtype,
             )
         assert decoder_input_ids is not None
         if encoder_outputs is None:
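
Passing mask_dtype=self.shared.weight.dtype keys the mask off the shared embedding's parameters, so the mask automatically follows whatever precision the model is converted to. A sketch of that property, assuming the era's attribute layout (model.model.shared) and the same hypothetical tiny config:

import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config)
print(model.model.shared.weight.dtype)  # torch.float32 for a fresh model

model.half()  # convert all parameters to fp16 ...
print(model.model.shared.weight.dtype)  # ... and the mask dtype follows: torch.float16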
tests/test_modeling_bart.py:
@@ -314,10 +314,16 @@ class BartHeadTests(unittest.TestCase):
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_generate_fp16(self):
         config, input_ids, batch_size = self._get_config_and_data(output_past=True)
-        input_ids = input_ids
+        attention_mask = input_ids.ne(1).to(torch_device)
+        model = BartForConditionalGeneration(config).eval().to(torch_device).half()
+        model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
+
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_base_model_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
         attention_mask = input_ids.ne(1).to(torch_device)
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
-        lm_model.generate(input_ids, attention_mask=attention_mask)
+        lm_model(input_ids, attention_mask=attention_mask)
 
     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
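
For reference, the two fp16 tests boil down to roughly this standalone script; it is a sketch that needs a CUDA device, and the hypothetical tiny config merely stands in for whatever _get_config_and_data constructs:

import torch
from transformers import BartConfig, BartForConditionalGeneration

torch_device = "cuda"  # both tests skip on CPU: half precision needs a GPU

# The real tests build their config via _get_config_and_data
# (output_past=True for the generate test, False for the forward-pass test).
config = BartConfig(
    vocab_size=99, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    output_past=True,
)
input_ids = torch.tensor([[71, 82, 18, 33, 2], [68, 34, 26, 58, 2]], device=torch_device)
attention_mask = input_ids.ne(1).to(torch_device)  # pad_token_id defaults to 1

model = BartForConditionalGeneration(config).eval().to(torch_device).half()
# test_generate_fp16: greedy generation must not crash under fp16
model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
# test_base_model_fp16: a plain forward pass must not crash under fp16 either
model(input_ids, attention_mask=attention_mask)

On a GPU machine the same coverage runs with, for example, pytest tests/test_modeling_bart.py -k fp16.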