chenpangpang / transformers
Commit 1567bef3 (unverified)
Authored Feb 15, 2023 by Joao Gante
Committed by GitHub on Feb 15, 2023
Generate: PT Dynamo without graph breaks in the main greedy/sample loop (#21648)
Parent: 7a5533b2

Showing 2 changed files with 10 additions and 7 deletions

src/transformers/generation/configuration_utils.py  +3 -0
src/transformers/modeling_utils.py  +7 -7
src/transformers/generation/configuration_utils.py
@@ -298,6 +298,9 @@ class GenerationConfig(PushToHubMixin):
         self.validate()
 
     def __eq__(self, other):
+        if not isinstance(other, GenerationConfig):
+            return False
+
         self_dict = self.__dict__.copy()
         other_dict = other.__dict__.copy()
         # ignore metadata
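The `isinstance` guard in `__eq__` can be illustrated with a minimal sketch. `TinyConfig` and the `_version` metadata key are hypothetical stand-ins chosen for this example, not the real `GenerationConfig` or its fields; the point is only that comparing against an unrelated object returns `False` before any attribute of `other` is touched.

```python
class TinyConfig:
    """Hypothetical stand-in for a config class; not the real GenerationConfig."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __eq__(self, other):
        # Guard first: objects of another type may not expose a usable __dict__.
        if not isinstance(other, TinyConfig):
            return False
        self_dict = self.__dict__.copy()
        other_dict = other.__dict__.copy()
        # Ignore metadata (the key name is an assumption for illustration).
        for metadata_field in ("_version",):
            self_dict.pop(metadata_field, None)
            other_dict.pop(metadata_field, None)
        return self_dict == other_dict


a = TinyConfig(max_length=20, _version="1")
b = TinyConfig(max_length=20, _version="2")

same = a == b                # differs only in ignored metadata
different_type = a == None   # guard returns False instead of reading None.__dict__
```

Returning early on a type mismatch keeps the comparison to a plain Python type check, which is presumably easier for a tracing compiler to handle than an exception path, although the commit itself does not state that motivation.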
src/transformers/modeling_utils.py
@@ -190,14 +190,14 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
             # Adding fix for https://github.com/pytorch/xla/issues/4152
             # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
             # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
-            if is_torch_tpu_available():
-                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
+            # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
+            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+                return torch.bfloat16
+            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
+                if t.dtype == torch.float:
                     return torch.bfloat16
-                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
-                    if t.dtype == torch.float:
-                        return torch.bfloat16
-                    if t.dtype == torch.double:
-                        return torch.float32
+                if t.dtype == torch.double:
+                    return torch.float32
             return t.dtype
 
     if last_dtype is not None:
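The reordering relies on the short-circuit behavior of Python's `and`: when the environment flag is not set, the right-hand operand (here `is_torch_tpu_available()`, which the NOTE says induces a graph break) is never called. The sketch below shows that effect with a call counter. `expensive_backend_check`, `wants_bf16`, and the locally defined `ENV_VARS_TRUE_VALUES` set are assumptions made for this example, not the library's code.

```python
# Assumed set of truthy strings, defined locally for this sketch.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

calls = {"count": 0}


def expensive_backend_check():
    """Hypothetical stand-in for a costly check such as is_torch_tpu_available()."""
    calls["count"] += 1
    return False  # pretend the backend is not present


def wants_bf16(env):
    """Return True only if the flag is set AND the backend check passes."""
    flag = env.get("XLA_USE_BF16", "0").upper()
    # `and` short-circuits: the right operand runs only when the flag is truthy.
    return flag in ENV_VARS_TRUE_VALUES and expensive_backend_check()


wants_bf16({})                     # flag unset: backend check is skipped
skipped = calls["count"] == 0
wants_bf16({"XLA_USE_BF16": "1"})  # flag set: backend check runs once
ran_once = calls["count"] == 1
```

In the common case where neither XLA flag is set, the cheap membership test fails first and the function falls through to `return t.dtype` without invoking the backend check at all.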