Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d444edb3
Unverified
Commit
d444edb3
authored
Jun 29, 2022
by
Younes Belkada
Committed by
GitHub
Jun 29, 2022
Browse files
OPT - Fix Softmax NaN in half precision mode (#17437)
parent
9fe2403b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
4 deletions
+31
-4
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_opt.py
+8
-3
tests/models/opt/test_modeling_opt.py
tests/models/opt/test_modeling_opt.py
+23
-1
No files found.
src/transformers/models/opt/modeling_opt.py
View file @
d444edb3
...
...
@@ -109,7 +109,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
return
super
().
forward
(
positions
+
self
.
offset
)
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
class
OPTAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
...
...
@@ -212,8 +211,14 @@ class OPTAttention(nn.Module):
f
"Attention mask should be of size
{
(
bsz
,
1
,
tgt_len
,
src_len
)
}
, but is
{
attention_mask
.
size
()
}
"
)
attn_weights
=
attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
+
attention_mask
attn_weights
=
torch
.
max
(
attn_weights
,
torch
.
tensor
(
torch
.
finfo
(
attn_weights
.
dtype
).
min
))
attn_weights
=
attn_weights
.
view
(
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
)
dtype_attn_weights
=
attn_weights
.
dtype
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if
dtype_attn_weights
==
torch
.
float16
:
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
dtype_attn_weights
)
else
:
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
)
if
layer_head_mask
is
not
None
:
...
...
@@ -382,7 +387,7 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix
=
"model"
supports_gradient_checkpointing
=
True
_no_split_modules
=
[
"OPTDecoderLayer"
]
_keys_to_ignore_on_load_unexpected
=
[
r
"decoder.version"
]
_keys_to_ignore_on_load_unexpected
=
[
r
"decoder
\
.version"
]
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
init_std
...
...
tests/models/opt/test_modeling_opt.py
View file @
d444edb3
...
...
@@ -22,7 +22,7 @@ import unittest
import
timeout_decorator
# noqa
from
transformers
import
OPTConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
require_torch_gpu
,
slow
,
torch_device
from
...generation.test_generation_utils
import
GenerationTesterMixin
from
...test_configuration_common
import
ConfigTester
...
...
@@ -428,3 +428,25 @@ class OPTGenerationTest(unittest.TestCase):
predicted_outputs
+=
generated_string
self
.
assertListEqual
(
predicted_outputs
,
EXPECTED_OUTPUTS
)
@
require_torch_gpu
def
test_batched_nan_fp16
(
self
):
# a bug manifested starting at models facebook/opt-1.3 and larger when running batched generations,
# therefore not using a tiny model, but the smallest model the problem was seen with which is opt-1.3b.
# please refer to this github thread: https://github.com/huggingface/transformers/pull/17437 for more details
model_name
=
"facebook/opt-1.3b"
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
model_name
,
use_fast
=
False
,
padding_side
=
"left"
)
model
=
OPTForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
torch
.
float16
,
use_cache
=
True
).
cuda
()
model
=
model
.
eval
()
batch
=
tokenizer
([
"Who are you?"
,
"Joe Biden is the president of"
],
padding
=
True
,
return_tensors
=
"pt"
)
input_ids
=
batch
[
"input_ids"
].
cuda
()
attention_mask
=
batch
[
"attention_mask"
].
cuda
()
with
torch
.
no_grad
():
outputs
=
model
(
input_ids
,
attention_mask
=
attention_mask
)
self
.
assertFalse
(
torch
.
isnan
(
outputs
.
logits
[
0
]).
any
().
item
()
)
# the first logits could contain NaNs if it fails
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment