Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c47d2592
Unverified
Commit
c47d2592
authored
Feb 07, 2022
by
Patrick von Platen
Committed by
GitHub
Feb 07, 2022
Browse files
[torch_int_div] Correct true division in generation (#15498)
* [torch_int_div] Correct true division in generation * up * up
parent
5f1918a4
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
51 additions
and
22 deletions
+51
-22
src/transformers/__init__.py
src/transformers/__init__.py
+1
-0
src/transformers/generation_utils.py
src/transformers/generation_utils.py
+5
-4
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+0
-11
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+2
-1
src/transformers/models/sew/modeling_sew.py
src/transformers/models/sew/modeling_sew.py
+2
-1
src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/sew_d/modeling_sew_d.py
+2
-1
src/transformers/models/unispeech/modeling_unispeech.py
src/transformers/models/unispeech/modeling_unispeech.py
+2
-1
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+2
-1
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+2
-1
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/wavlm/modeling_wavlm.py
+2
-1
src/transformers/pytorch_utils.py
src/transformers/pytorch_utils.py
+31
-0
No files found.
src/transformers/__init__.py
View file @
c47d2592
...
@@ -1561,6 +1561,7 @@ if is_torch_available():
...
@@ -1561,6 +1561,7 @@ if is_torch_available():
"get_polynomial_decay_schedule_with_warmup"
,
"get_polynomial_decay_schedule_with_warmup"
,
"get_scheduler"
,
"get_scheduler"
,
]
]
_import_structure
[
"pytorch_utils"
]
=
[]
_import_structure
[
"sagemaker"
]
=
[]
_import_structure
[
"sagemaker"
]
=
[]
_import_structure
[
"trainer"
]
=
[
"Trainer"
]
_import_structure
[
"trainer"
]
=
[
"Trainer"
]
_import_structure
[
"trainer_pt_utils"
]
=
[
"torch_distributed_zero_first"
]
_import_structure
[
"trainer_pt_utils"
]
=
[
"torch_distributed_zero_first"
]
...
...
src/transformers/generation_utils.py
View file @
c47d2592
...
@@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
...
@@ -48,6 +48,7 @@ from .generation_stopping_criteria import (
StoppingCriteriaList
,
StoppingCriteriaList
,
validate_stopping_criteria
,
validate_stopping_criteria
,
)
)
from
.pytorch_utils
import
torch_int_div
from
.utils
import
logging
from
.utils
import
logging
...
@@ -2024,7 +2025,7 @@ class GenerationMixin:
...
@@ -2024,7 +2025,7 @@ class GenerationMixin:
next_token_scores
,
2
*
num_beams
,
dim
=
1
,
largest
=
True
,
sorted
=
True
next_token_scores
,
2
*
num_beams
,
dim
=
1
,
largest
=
True
,
sorted
=
True
)
)
next_indices
=
(
next_tokens
/
vocab_size
)
.
long
()
next_indices
=
torch_int_div
(
next_tokens
,
vocab_size
)
next_tokens
=
next_tokens
%
vocab_size
next_tokens
=
next_tokens
%
vocab_size
# stateless
# stateless
...
@@ -2345,7 +2346,7 @@ class GenerationMixin:
...
@@ -2345,7 +2346,7 @@ class GenerationMixin:
next_token_scores
,
_indices
=
torch
.
sort
(
next_token_scores
,
descending
=
True
,
dim
=
1
)
next_token_scores
,
_indices
=
torch
.
sort
(
next_token_scores
,
descending
=
True
,
dim
=
1
)
next_tokens
=
torch
.
gather
(
next_tokens
,
-
1
,
_indices
)
next_tokens
=
torch
.
gather
(
next_tokens
,
-
1
,
_indices
)
next_indices
=
next_tokens
//
vocab_size
next_indices
=
torch_int_div
(
next_tokens
,
vocab_size
)
next_tokens
=
next_tokens
%
vocab_size
next_tokens
=
next_tokens
%
vocab_size
# stateless
# stateless
...
@@ -2678,7 +2679,7 @@ class GenerationMixin:
...
@@ -2678,7 +2679,7 @@ class GenerationMixin:
next_token_scores
,
2
*
group_size
,
dim
=
1
,
largest
=
True
,
sorted
=
True
next_token_scores
,
2
*
group_size
,
dim
=
1
,
largest
=
True
,
sorted
=
True
)
)
next_indices
=
next_tokens
//
vocab_size
next_indices
=
torch_int_div
(
next_tokens
,
vocab_size
)
next_tokens
=
next_tokens
%
vocab_size
next_tokens
=
next_tokens
%
vocab_size
# stateless
# stateless
...
@@ -2706,7 +2707,7 @@ class GenerationMixin:
...
@@ -2706,7 +2707,7 @@ class GenerationMixin:
# (beam_idx // group_size) -> batch_idx
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices
[
batch_group_indices
]
=
(
reordering_indices
[
batch_group_indices
]
=
(
num_beams
*
(
beam_idx
//
group_size
)
+
group_start_idx
+
(
beam_idx
%
group_size
)
num_beams
*
torch_int_div
(
beam_idx
,
group_size
)
+
group_start_idx
+
(
beam_idx
%
group_size
)
)
)
# Store scores, attentions and hidden_states when required
# Store scores, attentions and hidden_states when required
...
...
src/transformers/modeling_utils.py
View file @
c47d2592
...
@@ -23,7 +23,6 @@ from functools import partial
...
@@ -23,7 +23,6 @@ from functools import partial
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
packaging
import
version
from
torch
import
Tensor
,
device
,
nn
from
torch
import
Tensor
,
device
,
nn
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
...
@@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
...
@@ -2463,13 +2462,3 @@ def apply_chunking_to_forward(
return
torch
.
cat
(
output_chunks
,
dim
=
chunk_dim
)
return
torch
.
cat
(
output_chunks
,
dim
=
chunk_dim
)
return
forward_fn
(
*
input_tensors
)
return
forward_fn
(
*
input_tensors
)
def torch_int_div(tensor1, tensor2):
    """
    Divide `tensor1` by `tensor2` element-wise and round the quotient down, across PyTorch versions.

    `torch.div(..., rounding_mode="floor")` is tried first. Support for the `rounding_mode` keyword is detected by
    making the call instead of comparing version strings: a pre-release build such as `1.8.0a0+<hash>` sorts below
    `1.8.0` under PEP 440 and would be misclassified by a version gate, and parsing the version on every call is
    wasted work for callers that invoke this function in a loop.

    Args:
        tensor1: The dividend.
        tensor2: The divisor.

    Returns:
        The result of `torch.div(tensor1, tensor2, rounding_mode="floor")` when that call is accepted, otherwise
        the result of `tensor1 // tensor2`.
    """
    try:
        return torch.div(tensor1, tensor2, rounding_mode="floor")
    except TypeError:
        # A PyTorch build that does not know the `rounding_mode` keyword rejects the call with a `TypeError`;
        # use the `//` operator there, exactly as the former "< 1.8.0" branch did. Any other `TypeError` raised
        # by `torch.div` also lands here, in which case `//` either handles the operands or raises its own error.
        return tensor1 // tensor2
src/transformers/models/hubert/modeling_hubert.py
View file @
c47d2592
...
@@ -33,7 +33,8 @@ from ...file_utils import (
...
@@ -33,7 +33,8 @@ from ...file_utils import (
replace_return_docstrings
,
replace_return_docstrings
,
)
)
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_hubert
import
HubertConfig
from
.configuration_hubert
import
HubertConfig
...
...
src/transformers/models/sew/modeling_sew.py
View file @
c47d2592
...
@@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
...
@@ -29,7 +29,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...file_utils
import
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
from
...file_utils
import
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_sew
import
SEWConfig
from
.configuration_sew
import
SEWConfig
...
...
src/transformers/models/sew_d/modeling_sew_d.py
View file @
c47d2592
...
@@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
...
@@ -30,7 +30,8 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...file_utils
import
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
from
...file_utils
import
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_sew_d
import
SEWDConfig
from
.configuration_sew_d
import
SEWDConfig
...
...
src/transformers/models/unispeech/modeling_unispeech.py
View file @
c47d2592
...
@@ -35,7 +35,8 @@ from ...file_utils import (
...
@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings
,
replace_return_docstrings
,
)
)
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_unispeech
import
UniSpeechConfig
from
.configuration_unispeech
import
UniSpeechConfig
...
...
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
View file @
c47d2592
...
@@ -35,7 +35,8 @@ from ...file_utils import (
...
@@ -35,7 +35,8 @@ from ...file_utils import (
replace_return_docstrings
,
replace_return_docstrings
,
)
)
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_unispeech_sat
import
UniSpeechSatConfig
from
.configuration_unispeech_sat
import
UniSpeechSatConfig
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
c47d2592
...
@@ -41,7 +41,8 @@ from ...modeling_outputs import (
...
@@ -41,7 +41,8 @@ from ...modeling_outputs import (
SequenceClassifierOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
,
TokenClassifierOutput
,
)
)
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_wav2vec2
import
Wav2Vec2Config
from
.configuration_wav2vec2
import
Wav2Vec2Config
...
...
src/transformers/models/wavlm/modeling_wavlm.py
View file @
c47d2592
...
@@ -35,7 +35,8 @@ from ...file_utils import (
...
@@ -35,7 +35,8 @@ from ...file_utils import (
add_start_docstrings_to_model_forward
,
add_start_docstrings_to_model_forward
,
)
)
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
from
...modeling_outputs
import
BaseModelOutput
,
CausalLMOutput
,
SequenceClassifierOutput
,
TokenClassifierOutput
from
...modeling_utils
import
PreTrainedModel
,
torch_int_div
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
torch_int_div
from
...utils
import
logging
from
...utils
import
logging
from
.configuration_wavlm
import
WavLMConfig
from
.configuration_wavlm
import
WavLMConfig
...
...
src/transformers/pytorch_utils.py
0 → 100644
View file @
c47d2592
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
packaging
import
version
from
.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
def torch_int_div(tensor1, tensor2):
    """
    Divide `tensor1` by `tensor2` element-wise and round the quotient down, across PyTorch versions.

    `torch.div(..., rounding_mode="floor")` is tried first. Support for the `rounding_mode` keyword is detected by
    making the call instead of comparing version strings: a pre-release build such as `1.8.0a0+<hash>` sorts below
    `1.8.0` under PEP 440 and would be misclassified by a version gate, and parsing the version on every call is
    wasted work for callers that invoke this function in a loop.

    Args:
        tensor1: The dividend.
        tensor2: The divisor.

    Returns:
        The result of `torch.div(tensor1, tensor2, rounding_mode="floor")` when that call is accepted, otherwise
        the result of `tensor1 // tensor2`.
    """
    try:
        return torch.div(tensor1, tensor2, rounding_mode="floor")
    except TypeError:
        # A PyTorch build that does not know the `rounding_mode` keyword rejects the call with a `TypeError`;
        # use the `//` operator there, exactly as the former "< 1.8.0" branch did. Any other `TypeError` raised
        # by `torch.div` also lands here, in which case `//` either handles the operands or raises its own error.
        return tensor1 // tensor2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment