chenpangpang / transformers

Commit 1741d740, authored Mar 05, 2020 by Thomas Wolf, committed by GitHub on Mar 05, 2020

Merge pull request #3145 from sshleifer/bartfp16

[Bart] FP16 Support

Parents: bbabbc16, 14d40584
Showing 2 changed files with 11 additions and 4 deletions:

  src/transformers/modeling_bart.py   +4 / -4
  tests/test_modeling_bart.py         +7 / -0
src/transformers/modeling_bart.py
@@ -640,9 +640,9 @@ class SelfAttention(nn.Module):
         reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
         attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
-        attn_weights = attn_weights_float.type_as(attn_weights)
-        attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
+        attn_weights = F.softmax(attn_weights, dim=-1)
+        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
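This hunk stops upcasting the attention weights to float32: softmax now runs in the input dtype, and dropout is applied to the (possibly float16) weights directly. Presumably this is what unblocks half precision, since the old code produced float32 attn_probs while v stayed float16, and torch.bmm does not accept mixed dtypes. A minimal sketch, not part of the commit, illustrating that behaviour (shapes and dropout probability are arbitrary):

# Sketch (not from the commit): why the float32 upcast breaks fp16 attention.
import torch
import torch.nn.functional as F

# Older PyTorch builds lack some fp16 kernels on CPU, so prefer CUDA if present.
device = "cuda" if torch.cuda.is_available() else "cpu"
attn_weights = torch.randn(2, 4, 4, dtype=torch.float16, device=device)  # (bsz*heads, tgt, src)
v = torch.randn(2, 4, 8, dtype=torch.float16, device=device)             # (bsz*heads, src, head_dim)

# Old path: softmax upcasts to float32, so the probabilities no longer match v's dtype.
attn_probs_fp32 = F.dropout(F.softmax(attn_weights, dim=-1, dtype=torch.float32), p=0.1)
# torch.bmm(attn_probs_fp32, v)  # would raise a RuntimeError about mismatched dtypes

# New path: softmax keeps the input dtype, so the whole chain stays float16.
attn_probs_fp16 = F.dropout(F.softmax(attn_weights, dim=-1), p=0.1)
out = torch.bmm(attn_probs_fp16, v)
assert out.dtype == torch.float16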
@@ -696,7 +696,7 @@ class SelfAttention(nn.Module):
         elif prev_key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
             if prev_key_padding_mask.is_cuda:
-                filler = filler.cuda()
+                filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
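The second hunk swaps the hard-coded filler.cuda() for filler.to(prev_key_padding_mask.device), so the zero filler is created on whatever device the cached mask already lives on. A short standalone sketch of that pattern, using a hypothetical helper name (pad_key_padding_mask does not appear in the commit):

# Sketch (not from the commit) of the device-agnostic padding pattern.
import torch

def pad_key_padding_mask(mask: torch.Tensor, src_len: int) -> torch.Tensor:
    """Right-pad a (batch_size, seq_len) mask with zeros up to src_len."""
    batch_size = mask.size(0)
    filler = torch.zeros(batch_size, src_len - mask.size(1))
    filler = filler.to(mask.device)  # works on CPU and GPU alike, unlike filler.cuda()
    return torch.cat([mask.float(), filler.float()], dim=1)

mask = torch.ones(2, 3)
print(pad_key_padding_mask(mask, 5).shape)  # torch.Size([2, 5])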
tests/test_modeling_bart.py
@@ -294,6 +294,13 @@ class BartHeadTests(unittest.TestCase):
             bart_toks = tokenizer.encode(ex, return_tensors="pt")
             _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
 
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_generate_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=True)
+        attention_mask = input_ids.ne(1)
+        lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
+        lm_model.generate(input_ids, attention_mask)
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
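The added test_generate_fp16 builds a small Bart LM from a synthetic config, casts it to half precision, and simply checks that generate runs on a GPU device. A hedged sketch of the same exercise against the public API follows; the tiny config values are arbitrary, and current transformers releases expose BartForConditionalGeneration where this 2020-era test used BartForMaskedLM:

# Sketch (not from the commit): exercising Bart generation in half precision.
import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99, d_model=24,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=48,
)
input_ids = torch.randint(3, 99, (2, 7))
attention_mask = input_ids.ne(config.pad_token_id)  # pad token id defaults to 1

if torch.cuda.is_available():  # fp16 generation is skipped on CPU, as in the test
    model = BartForConditionalGeneration(config).eval().to("cuda").half()
    output_ids = model.generate(
        input_ids.to("cuda"), attention_mask=attention_mask.to("cuda")
    )
    print(output_ids.shape)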