Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c1aa0edb
Unverified
Commit
c1aa0edb
authored
Aug 02, 2024
by
Sanchit Gandhi
Committed by
GitHub
Aug 02, 2024
Browse files
[generate] only require an attention mask for mps with torch<2.4 (#32367)
* up * style * stopping
parent
083e13b7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
4 deletions
+9
-4
src/transformers/generation/stopping_criteria.py
src/transformers/generation/stopping_criteria.py
+4
-1
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+4
-3
src/transformers/pytorch_utils.py
src/transformers/pytorch_utils.py
+1
-0
No files found.
src/transformers/generation/stopping_criteria.py
View file @
c1aa0edb
...
@@ -9,6 +9,8 @@ import numpy as np
...
@@ -9,6 +9,8 @@ import numpy as np
import
torch
import
torch
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
transformers.pytorch_utils
import
is_torch_greater_or_equal_than_2_4
from
..tokenization_utils_base
import
PreTrainedTokenizerBase
from
..tokenization_utils_base
import
PreTrainedTokenizerBase
from
..utils
import
add_start_docstrings
,
logging
from
..utils
import
add_start_docstrings
,
logging
...
@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
...
@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
self
.
eos_token_id
=
self
.
eos_token_id
.
to
(
input_ids
.
device
)
self
.
eos_token_id
=
self
.
eos_token_id
.
to
(
input_ids
.
device
)
if
input_ids
.
device
.
type
==
"mps"
:
if
input_ids
.
device
.
type
==
"mps"
and
not
is_torch_greater_or_equal_than_2_4
:
# TODO: remove this workaround when we stop supporting torch<=2.3
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done
=
(
is_done
=
(
input_ids
[:,
-
1
]
input_ids
[:,
-
1
]
...
...
src/transformers/generation/utils.py
View file @
c1aa0edb
...
@@ -47,6 +47,7 @@ from ..models.auto import (
...
@@ -47,6 +47,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
MODEL_FOR_VISION_2_SEQ_MAPPING
,
)
)
from
..pytorch_utils
import
is_torch_greater_or_equal_than_2_4
from
..tokenization_utils
import
ExtensionsTrie
from
..tokenization_utils
import
ExtensionsTrie
from
..utils
import
(
from
..utils
import
(
ModelOutput
,
ModelOutput
,
...
@@ -488,10 +489,10 @@ class GenerationMixin:
...
@@ -488,10 +489,10 @@ class GenerationMixin:
return
default_attention_mask
return
default_attention_mask
# Otherwise we have may have information -> try to infer the attention mask
# Otherwise we have may have information -> try to infer the attention mask
if
inputs
.
device
.
type
==
"mps"
:
if
inputs
.
device
.
type
==
"mps"
and
not
is_torch_greater_or_equal_than_2_4
:
# mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
# mps does not support torch.isin
for torch<2.4
(https://github.com/pytorch/pytorch/issues/77764)
raise
ValueError
(
raise
ValueError
(
"Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or u
se a different device.
"
"Can't infer missing attention mask on `mps` device
for torch<2.4
. Please provide an `attention_mask` or u
pgrade to torch>=2.4
"
)
)
is_pad_token_in_inputs
=
(
pad_token_id
is
not
None
)
and
(
is_pad_token_in_inputs
=
(
pad_token_id
is
not
None
)
and
(
...
...
src/transformers/pytorch_utils.py
View file @
c1aa0edb
...
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
...
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base
=
version
.
parse
(
version
.
parse
(
torch
.
__version__
).
base_version
)
parsed_torch_version_base
=
version
.
parse
(
version
.
parse
(
torch
.
__version__
).
base_version
)
is_torch_greater_or_equal_than_2_4
=
parsed_torch_version_base
>=
version
.
parse
(
"2.4"
)
is_torch_greater_or_equal_than_2_3
=
parsed_torch_version_base
>=
version
.
parse
(
"2.3"
)
is_torch_greater_or_equal_than_2_3
=
parsed_torch_version_base
>=
version
.
parse
(
"2.3"
)
is_torch_greater_or_equal_than_2_2
=
parsed_torch_version_base
>=
version
.
parse
(
"2.2"
)
is_torch_greater_or_equal_than_2_2
=
parsed_torch_version_base
>=
version
.
parse
(
"2.2"
)
is_torch_greater_or_equal_than_2_1
=
parsed_torch_version_base
>=
version
.
parse
(
"2.1"
)
is_torch_greater_or_equal_than_2_1
=
parsed_torch_version_base
>=
version
.
parse
(
"2.1"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment