Unverified Commit c47d2592 authored by Patrick von Platen, committed by GitHub

[torch_int_div] Correct true division in generation (#15498)

* [torch_int_div] Correct true division in generation

* up

* up
parent 5f1918a4
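
For context on the fix below: on recent PyTorch, `/` on integer tensors is true division (the result is floating point), so an expression like `(next_tokens / vocab_size).long()` round-trips beam indices through floats, and newer releases also warn that tensor `//` is deprecated. The exact spelling, `torch.div(..., rounding_mode="floor")`, only exists from PyTorch 1.8, which is why the helper introduced here is version-gated. A minimal sketch of the difference, with an illustrative vocabulary size that is not taken from the diff:

    import torch

    vocab_size = 50257                         # illustrative, GPT-2-sized
    flat = torch.tensor([3 * vocab_size + 7])  # flat index encoding (beam 3, token 7)

    # True division: floating-point result, truncated back by .long().
    beam_float = (flat / vocab_size).long()

    # Exact integer floor division (PyTorch >= 1.8).
    beam_int = torch.div(flat, vocab_size, rounding_mode="floor")

    assert beam_float.item() == beam_int.item() == 3  # agree here; very large indices can drift in float32
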
@@ -1561,6 +1561,7 @@ if is_torch_available():
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
]
_import_structure["pytorch_utils"] = []
_import_structure["sagemaker"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
@@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
StoppingCriteriaList,
validate_stopping_criteria,
)
+ from .pytorch_utils import torch_int_div
from .utils import logging
@@ -2024,7 +2025,7 @@ class GenerationMixin:
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
- next_indices = (next_tokens / vocab_size).long()
+ next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
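
For context: `next_tokens` at this point are flat top-k indices over scores of shape `(batch_size, num_beams * vocab_size)`, so floor division recovers the beam and the modulo recovers the token id. A toy decomposition with made-up sizes:

    import torch

    num_beams, vocab_size = 2, 10    # toy sizes, for illustration only
    flat = torch.tensor([[17, 3]])   # top-k indices over num_beams * vocab_size = 20 columns

    next_indices = torch.div(flat, vocab_size, rounding_mode="floor")  # tensor([[1, 0]]) -> beams
    next_tokens = flat % vocab_size                                    # tensor([[7, 3]]) -> token ids
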
@@ -2345,7 +2346,7 @@ class GenerationMixin:
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
- next_indices = next_tokens // vocab_size
+ next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
@@ -2678,7 +2679,7 @@ class GenerationMixin:
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
- next_indices = next_tokens // vocab_size
+ next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
@@ -2706,7 +2707,7 @@ class GenerationMixin:
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
- num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
+ num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
)
# Store scores, attentions and hidden_states when required
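
For context: the formula above maps a group-local beam index back to a slot in the full beam layout. A worked example with assumed sizes (not taken from the diff):

    # Assumed: num_beams = 6 total beams per batch item, groups of group_size = 2,
    # and this group starts at beam offset group_start_idx = 2.
    num_beams, group_size, group_start_idx = 6, 2, 2

    beam_idx = 5  # group-local: batch item 5 // 2 = 2, offset 5 % 2 = 1 inside the group
    global_slot = num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
    assert global_slot == 15  # batch item 2's beams occupy slots 12..17; 12 + 2 + 1 = 15
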
@@ -23,7 +23,6 @@ from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
- from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
@@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
return torch.cat(output_chunks, dim=chunk_dim)
return forward_fn(*input_tensors)
- def torch_int_div(tensor1, tensor2):
-     """
-     A function that performs integer division across different versions of PyTorch.
-     """
-     if version.parse(torch.__version__) < version.parse("1.8.0"):
-         return tensor1 // tensor2
-     else:
-         return torch.div(tensor1, tensor2, rounding_mode="floor")
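
For context: the version gate exists because the `rounding_mode` argument of `torch.div` first appeared in PyTorch 1.8, while `tensor1 // tensor2` still works (with deprecation warnings) on newer releases. The comparison relies on `packaging` rather than string comparison, which matters once minor versions reach two digits:

    from packaging import version

    assert version.parse("1.7.1") < version.parse("1.8.0")
    assert not version.parse("1.10.0") < version.parse("1.8.0")  # a plain string compare would get this wrong
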
@@ -33,7 +33,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_hubert import HubertConfig
@@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_sew import SEWConfig
@@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_sew_d import SEWDConfig
@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_unispeech import UniSpeechConfig
@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_unispeech_sat import UniSpeechSatConfig
@@ -41,7 +41,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -35,7 +35,8 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
)
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
- from ...modeling_utils import PreTrainedModel, torch_int_div
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import torch_int_div
from ...utils import logging
from .configuration_wavlm import WavLMConfig
new file: src/transformers/pytorch_utils.py
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from packaging import version
+
+ from .utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ def torch_int_div(tensor1, tensor2):
+     """
+     A function that performs integer division across different versions of PyTorch.
+     """
+     if version.parse(torch.__version__) < version.parse("1.8.0"):
+         return tensor1 // tensor2
+     else:
+         return torch.div(tensor1, tensor2, rounding_mode="floor")
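
A quick, self-contained check of the new helper (assumes a recent PyTorch; the values are arbitrary):

    import torch
    from transformers.pytorch_utils import torch_int_div

    a = torch.tensor([0, 7, 19, 123456])
    b = torch.tensor([3, 3, 10, 1000])

    out = torch_int_div(a, b)
    assert out.dtype == torch.int64        # stays integer, no float round-trip
    assert out.tolist() == [0, 2, 1, 123]  # elementwise floor division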