chenpangpang / transformers · Commits · c47d2592

Unverified commit c47d2592, authored Feb 07, 2022 by Patrick von Platen, committed by GitHub on Feb 07, 2022.
[torch_int_div] Correct true division in generation (#15498)

* [torch_int_div] Correct true division in generation
* up
* up
Parent: 5f1918a4
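
Why the helper exists: the generation code split flat top-k indices into a beam index and a token id with expressions like (next_tokens / vocab_size).long() and next_tokens // vocab_size. On recent PyTorch, "/" on integer tensors performs true division and returns floats, and "//" on tensors began emitting __floordiv__ deprecation warnings around torch 1.8; torch.div(..., rounding_mode="floor") is the forward-compatible spelling. A minimal sketch of the difference, assuming a PyTorch >= 1.8 install (toy values, not code from the commit):

import torch

next_tokens = torch.tensor([3, 7, 12])
vocab_size = 5

# True division silently promotes to float, so an extra .long() cast was needed:
print(next_tokens / vocab_size)  # tensor([0.6000, 1.4000, 2.4000])

# Floor division with an explicit rounding mode keeps the integer dtype
# and raises no deprecation warning:
print(torch.div(next_tokens, vocab_size, rounding_mode="floor"))  # tensor([0, 1, 2])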
Showing 11 changed files with 51 additions and 22 deletions (+51 −22)
src/transformers/__init__.py                                     +1   −0
src/transformers/generation_utils.py                             +5   −4
src/transformers/modeling_utils.py                               +0   −11
src/transformers/models/hubert/modeling_hubert.py                +2   −1
src/transformers/models/sew/modeling_sew.py                      +2   −1
src/transformers/models/sew_d/modeling_sew_d.py                  +2   −1
src/transformers/models/unispeech/modeling_unispeech.py          +2   −1
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py  +2   −1
src/transformers/models/wav2vec2/modeling_wav2vec2.py            +2   −1
src/transformers/models/wavlm/modeling_wavlm.py                  +2   −1
src/transformers/pytorch_utils.py                                +31  −0
src/transformers/__init__.py

@@ -1561,6 +1561,7 @@ if is_torch_available():
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
     ]
+    _import_structure["pytorch_utils"] = []
     _import_structure["sagemaker"] = []
     _import_structure["trainer"] = ["Trainer"]
     _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
src/transformers/generation_utils.py

@@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
     StoppingCriteriaList,
     validate_stopping_criteria,
 )
+from .pytorch_utils import torch_int_div
 from .utils import logging
@@ -2024,7 +2025,7 @@ class GenerationMixin:
             next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
         )

-        next_indices = (next_tokens / vocab_size).long()
+        next_indices = torch_int_div(next_tokens, vocab_size)
         next_tokens = next_tokens % vocab_size

         # stateless
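The pattern in this hunk (and the two similar hunks below) decodes flat top-k indices over the beams x vocab score matrix back into a beam index and a token id. A toy sketch of that arithmetic, using the helper this commit introduces (shapes are illustrative, not the library's exact ones):

import torch
from transformers.pytorch_utils import torch_int_div

num_beams, vocab_size = 2, 4
# Scores for one batch element, flattened over (num_beams * vocab_size):
next_token_scores = torch.tensor([[0.1, 0.9, 0.2, 0.3, 0.8, 0.4, 0.7, 0.5]])

_, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
next_indices = torch_int_div(next_tokens, vocab_size)  # which beam each candidate came from
next_tokens = next_tokens % vocab_size                 # token id within that beam's vocab

print(next_indices)  # tensor([[0, 1, 1, 1]])
print(next_tokens)   # tensor([[1, 0, 2, 3]])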
@@ -2345,7 +2346,7 @@ class GenerationMixin:
         next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
         next_tokens = torch.gather(next_tokens, -1, _indices)

-        next_indices = next_tokens // vocab_size
+        next_indices = torch_int_div(next_tokens, vocab_size)
         next_tokens = next_tokens % vocab_size

         # stateless
@@ -2678,7 +2679,7 @@ class GenerationMixin:
             next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
         )

-        next_indices = next_tokens // vocab_size
+        next_indices = torch_int_div(next_tokens, vocab_size)
         next_tokens = next_tokens % vocab_size

         # stateless
@@ -2706,7 +2707,7 @@ class GenerationMixin:
             # (beam_idx // group_size) -> batch_idx
             # (beam_idx % group_size) -> offset of idx inside the group
             reordering_indices[batch_group_indices] = (
-                num_beams * (beam_idx // group_size) + group_start_idx + (beam_idx % group_size)
+                num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
             )

             # Store scores, attentions and hidden_states when required
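In the group beam search hunk above, the changed line maps a group-local flat beam index back to a global beam index for cache reordering: torch_int_div(beam_idx, group_size) recovers the batch element and beam_idx % group_size the offset inside the group. Toy numbers (hypothetical, just to show the arithmetic):

import torch
from transformers.pytorch_utils import torch_int_div

num_beams, group_size, group_start_idx = 4, 2, 2
beam_idx = torch.tensor([0, 1, 2, 3])  # two beams for batch 0, then two for batch 1

global_idx = (
    num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
)
print(global_idx)  # tensor([2, 3, 6, 7])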
src/transformers/modeling_utils.py

@@ -23,7 +23,6 @@ from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
-from packaging import version
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss
@@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
         return torch.cat(output_chunks, dim=chunk_dim)

     return forward_fn(*input_tensors)
-
-
-def torch_int_div(tensor1, tensor2):
-    """
-    A function that performs integer division across different versions of PyTorch.
-    """
-    if version.parse(torch.__version__) < version.parse("1.8.0"):
-        return tensor1 // tensor2
-    else:
-        return torch.div(tensor1, tensor2, rounding_mode="floor")
src/transformers/models/hubert/modeling_hubert.py

@@ -33,7 +33,8 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_hubert import HubertConfig
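The same import swap repeats across the speech models below. These models use torch_int_div, among other places, when computing the convolutional feature extractor's output lengths; a sketch of that formula under stated assumptions (toy values, standard Conv1d length arithmetic, not an exact excerpt from the library):

import torch
from transformers.pytorch_utils import torch_int_div

def conv_out_length(input_length, kernel_size, stride):
    # 1D convolution output length: floor((L - kernel_size) / stride) + 1
    return torch_int_div(input_length - kernel_size, stride) + 1

lengths = torch.tensor([16000, 8000])  # raw audio sample counts
print(conv_out_length(lengths, kernel_size=10, stride=5))  # tensor([3199, 1599])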
src/transformers/models/sew/modeling_sew.py

@@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_sew import SEWConfig
src/transformers/models/sew_d/modeling_sew_d.py

@@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_sew_d import SEWDConfig
src/transformers/models/unispeech/modeling_unispeech.py

@@ -35,7 +35,8 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_unispeech import UniSpeechConfig
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py

@@ -35,7 +35,8 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_unispeech_sat import UniSpeechSatConfig
src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -41,7 +41,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
src/transformers/models/wavlm/modeling_wavlm.py

@@ -35,7 +35,8 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
-from ...modeling_utils import PreTrainedModel, torch_int_div
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import logging
 from .configuration_wavlm import WavLMConfig
src/transformers/pytorch_utils.py  (new file, 0 → 100644)

+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from packaging import version
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def torch_int_div(tensor1, tensor2):
+    """
+    A function that performs integer division across different versions of PyTorch.
+    """
+    if version.parse(torch.__version__) < version.parse("1.8.0"):
+        return tensor1 // tensor2
+    else:
+        return torch.div(tensor1, tensor2, rounding_mode="floor")
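
A quick usage sketch for the new helper (assumes a transformers checkout containing this commit and any supported PyTorch); the result keeps an integer dtype on both version branches:

import torch
from transformers.pytorch_utils import torch_int_div

a = torch.tensor([10, 11, 19])
out = torch_int_div(a, 5)
print(out, out.dtype)  # tensor([2, 2, 3]) torch.int64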