"docs/source/en/tasks/sequence_classification.md" did not exist on "57f25f4b7fb85ff069f8701372710b2a3207bf2d"
Unverified Commit c1aa0edb authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[generate] only require an attention mask for mps with torch<2.4 (#32367)

* up

* style

* stopping
parent 083e13b7
...@@ -9,6 +9,8 @@ import numpy as np ...@@ -9,6 +9,8 @@ import numpy as np
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from ..tokenization_utils_base import PreTrainedTokenizerBase from ..tokenization_utils_base import PreTrainedTokenizerBase
from ..utils import add_start_docstrings, logging from ..utils import add_start_docstrings, logging
...@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria): ...@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
self.eos_token_id = self.eos_token_id.to(input_ids.device) self.eos_token_id = self.eos_token_id.to(input_ids.device)
if input_ids.device.type == "mps": if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# TODO: remove this workaround when we stop supporting torch<=2.3
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075 # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = ( is_done = (
input_ids[:, -1] input_ids[:, -1]
......
...@@ -47,6 +47,7 @@ from ..models.auto import ( ...@@ -47,6 +47,7 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING,
) )
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
from ..tokenization_utils import ExtensionsTrie from ..tokenization_utils import ExtensionsTrie
from ..utils import ( from ..utils import (
ModelOutput, ModelOutput,
...@@ -488,10 +489,10 @@ class GenerationMixin: ...@@ -488,10 +489,10 @@ class GenerationMixin:
return default_attention_mask return default_attention_mask
# Otherwise we have may have information -> try to infer the attention mask # Otherwise we have may have information -> try to infer the attention mask
if inputs.device.type == "mps": if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
# mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764) # mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
raise ValueError( raise ValueError(
"Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device." "Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
) )
is_pad_token_in_inputs = (pad_token_id is not None) and ( is_pad_token_in_inputs = (pad_token_id is not None) and (
......
...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) ...@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3") is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2") is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment