chenpangpang / transformers · Commit c1aa0edb (unverified)
Authored Aug 02, 2024 by Sanchit Gandhi, committed by GitHub on Aug 02, 2024
[generate] only require an attention mask for mps with torch<2.4 (#32367)
* up
* style
* stopping
parent 083e13b7
Showing 3 changed files with 9 additions and 4 deletions:

src/transformers/generation/stopping_criteria.py  +4 -1
src/transformers/generation/utils.py              +4 -3
src/transformers/pytorch_utils.py                 +1 -0
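For context, a minimal sketch of the PyTorch limitation this commit works around (an illustration assuming an `mps`-capable machine, not part of the commit itself): `torch.isin` was unsupported on the `mps` backend before torch 2.4 (https://github.com/pytorch/pytorch/issues/77764).

# Illustrative sketch only: torch.isin on "mps" raised an error for torch<2.4
# (https://github.com/pytorch/pytorch/issues/77764); on torch>=2.4 it works.
import torch

if torch.backends.mps.is_available():
    input_ids = torch.tensor([[1, 5, 2]], device="mps")
    eos_token_id = torch.tensor([2], device="mps")
    # The EOS check as generation code would like to write it:
    is_done = torch.isin(input_ids[:, -1], eos_token_id)
    print(is_done)  # tensor([True], device='mps:0') on torch>=2.4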
src/transformers/generation/stopping_criteria.py
@@ -9,6 +9,8 @@ import numpy as np
 import torch
 from torch.nn import functional as F
 
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+
 from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import add_start_docstrings, logging
@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         self.eos_token_id = self.eos_token_id.to(input_ids.device)
-        if input_ids.device.type == "mps":
+        if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # TODO: remove this workaround when we stop supporting torch<=2.3
             # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
             is_done = (
                 input_ids[:, -1]
...
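The truncated `is_done` expression above computes the EOS check without `torch.isin`. A hedged sketch of such an isin-free fallback (illustrative only; the exact body is hidden by the truncated diff):

# Sketch of an isin-free EOS check usable on mps with torch<2.4: broadcast
# eq() against every EOS id, then reduce with any() instead of torch.isin.
import torch

def eos_is_done(input_ids: torch.LongTensor, eos_token_id: torch.LongTensor) -> torch.BoolTensor:
    last_tokens = input_ids[:, -1]  # shape (batch,)
    # (num_eos, 1) vs (batch,) broadcasts to (num_eos, batch); any() -> (batch,)
    return last_tokens.eq(eos_token_id.unsqueeze(1)).any(dim=0)

input_ids = torch.tensor([[1, 5, 2], [1, 5, 7]])
print(eos_is_done(input_ids, torch.tensor([2])))  # tensor([ True, False])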
src/transformers/generation/utils.py
@@ -47,6 +47,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
+from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
...
@@ -488,10 +489,10 @@ class GenerationMixin:
             return default_attention_mask
 
         # Otherwise we may have information -> try to infer the attention mask
-        if inputs.device.type == "mps":
-            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
+        if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
             raise ValueError(
-                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
+                "Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
             )
         is_pad_token_in_inputs = (pad_token_id is not None) and (
...
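On the caller's side, the error above is avoided by passing the mask explicitly. A usage sketch (the `gpt2` checkpoint is only an illustrative choice):

# On an mps device with torch<2.4, pass the tokenizer's attention mask to
# generate() explicitly instead of letting it be inferred.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "mps" if torch.backends.mps.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

inputs = tokenizer("Hello, world", return_tensors="pt").to(device)
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,  # explicit mask, nothing to infer
    max_new_tokens=10,
)
print(tokenizer.decode(outputs[0]))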
src/transformers/pytorch_utils.py
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
 is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
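A hedged sketch of how such a flag is typically consumed downstream (`safe_isin` is a hypothetical helper for illustration, not a transformers API):

# Hypothetical helper gating torch.isin on the version flag defined above.
import torch
from packaging import version

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")

def safe_isin(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    """torch.isin where supported; eq/any fallback on mps with torch<2.4."""
    if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
        return elements.unsqueeze(-1).eq(test_elements).any(dim=-1)
    return torch.isin(elements, test_elements)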