Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7d65697d
Unverified
Commit
7d65697d
authored
Aug 07, 2023
by
Joao Gante
Committed by
GitHub
Aug 07, 2023
Browse files
Generate: remove Marian hack (#25294)
Remove Marian hack
parent
14510938
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
0 additions
and
58 deletions
+0
-58
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+0
-21
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+0
-18
src/transformers/models/marian/modeling_marian.py
src/transformers/models/marian/modeling_marian.py
+0
-4
src/transformers/models/marian/modeling_tf_marian.py
src/transformers/models/marian/modeling_tf_marian.py
+0
-15
No files found.
src/transformers/generation/tf_utils.py
View file @
7d65697d
...
@@ -474,27 +474,6 @@ class TFGenerationMixin:
...
@@ -474,27 +474,6 @@ class TFGenerationMixin:
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
)
def
adjust_logits_during_generation
(
self
,
logits
,
cur_len
,
max_length
,
forced_bos_token_id
,
forced_eos_token_id
,
**
kwargs
):
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
"""
vocab_size
=
getattr
(
self
.
config
,
"vocab_size"
,
None
)
if
vocab_size
is
None
and
self
.
config
.
is_encoder_decoder
:
decoder_config
=
getattr
(
self
.
config
,
"decoder"
,
None
)
if
decoder_config
is
not
None
:
vocab_size
=
getattr
(
self
.
config
.
decoder
,
"vocab_size"
,
None
)
if
cur_len
==
1
and
forced_bos_token_id
is
not
None
:
vocab_range
=
tf
.
constant
(
range
(
vocab_size
))
return
tf
.
where
(
vocab_range
!=
forced_bos_token_id
,
-
1e8
,
logits
)
elif
cur_len
==
max_length
-
1
and
forced_eos_token_id
is
not
None
:
vocab_range
=
tf
.
constant
(
range
(
vocab_size
))
return
tf
.
where
(
vocab_range
!=
forced_eos_token_id
,
-
1e8
,
logits
)
else
:
return
logits
def
compute_transition_scores
(
def
compute_transition_scores
(
self
,
self
,
sequences
:
tf
.
Tensor
,
sequences
:
tf
.
Tensor
,
...
...
src/transformers/generation/utils.py
View file @
7d65697d
...
@@ -578,12 +578,6 @@ class GenerationMixin:
...
@@ -578,12 +578,6 @@ class GenerationMixin:
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
)
inputs
=
self
.
_maybe_initialize_input_ids_for_generation
(
inputs
,
bos_token_id
,
model_kwargs
)
return
inputs
,
input_name
,
model_kwargs
return
inputs
,
input_name
,
model_kwargs
def
adjust_logits_during_generation
(
self
,
logits
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
FloatTensor
:
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
"""
return
logits
def
_maybe_initialize_input_ids_for_generation
(
def
_maybe_initialize_input_ids_for_generation
(
self
,
self
,
inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -3060,9 +3054,6 @@ class GenerationMixin:
...
@@ -3060,9 +3054,6 @@ class GenerationMixin:
continue
# don't waste resources running the code we don't need
continue
# don't waste resources running the code we don't need
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits
=
self
.
adjust_logits_during_generation
(
next_token_logits
,
cur_len
=
cur_len
)
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_logits
,
dim
=-
1
next_token_logits
,
dim
=-
1
)
# (batch_size * num_beams, vocab_size)
)
# (batch_size * num_beams, vocab_size)
...
@@ -3388,9 +3379,6 @@ class GenerationMixin:
...
@@ -3388,9 +3379,6 @@ class GenerationMixin:
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits
=
self
.
adjust_logits_during_generation
(
next_token_logits
,
cur_len
=
cur_len
)
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_logits
,
dim
=-
1
next_token_logits
,
dim
=-
1
)
# (batch_size * num_beams, vocab_size)
)
# (batch_size * num_beams, vocab_size)
...
@@ -3751,9 +3739,6 @@ class GenerationMixin:
...
@@ -3751,9 +3739,6 @@ class GenerationMixin:
# select outputs of beams of current group only
# select outputs of beams of current group only
next_token_logits
=
outputs
.
logits
[
batch_group_indices
,
-
1
,
:]
next_token_logits
=
outputs
.
logits
[
batch_group_indices
,
-
1
,
:]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits
=
self
.
adjust_logits_during_generation
(
next_token_logits
,
cur_len
=
cur_len
)
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_logits
,
dim
=-
1
next_token_logits
,
dim
=-
1
)
# (batch_size * group_size, vocab_size)
)
# (batch_size * group_size, vocab_size)
...
@@ -4110,9 +4095,6 @@ class GenerationMixin:
...
@@ -4110,9 +4095,6 @@ class GenerationMixin:
continue
# don't waste resources running the code we don't need
continue
# don't waste resources running the code we don't need
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
next_token_logits
=
outputs
.
logits
[:,
-
1
,
:]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits
=
self
.
adjust_logits_during_generation
(
next_token_logits
,
cur_len
=
cur_len
)
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_scores
=
nn
.
functional
.
log_softmax
(
next_token_logits
,
dim
=-
1
next_token_logits
,
dim
=-
1
)
# (batch_size * num_beams, vocab_size)
)
# (batch_size * num_beams, vocab_size)
...
...
src/transformers/models/marian/modeling_marian.py
View file @
7d65697d
...
@@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
...
@@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
def
adjust_logits_during_generation
(
self
,
logits
,
cur_len
):
logits
[:,
self
.
config
.
pad_token_id
]
=
float
(
"-inf"
)
# never predict pad token.
return
logits
@
staticmethod
@
staticmethod
def
_reorder_cache
(
past_key_values
,
beam_idx
):
def
_reorder_cache
(
past_key_values
,
beam_idx
):
reordered_past
=
()
reordered_past
=
()
...
...
src/transformers/models/marian/modeling_tf_marian.py
View file @
7d65697d
...
@@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
tf
.
Tensor
):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
tf
.
Tensor
):
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
def
adjust_logits_during_generation
(
self
,
logits
,
cur_len
,
max_length
,
forced_bos_token_id
,
forced_eos_token_id
,
**
kwargs
):
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
vocab_range
=
tf
.
constant
(
range
(
self
.
config
.
vocab_size
))
logits
=
tf
.
where
(
vocab_range
==
self
.
config
.
pad_token_id
,
LARGE_NEGATIVE
,
logits
)
if
cur_len
==
1
and
forced_bos_token_id
is
not
None
:
vocab_range
=
tf
.
constant
(
range
(
self
.
config
.
vocab_size
))
return
tf
.
where
(
vocab_range
!=
forced_bos_token_id
,
LARGE_NEGATIVE
,
logits
)
elif
cur_len
==
max_length
-
1
and
forced_eos_token_id
is
not
None
:
vocab_range
=
tf
.
constant
(
range
(
self
.
config
.
vocab_size
))
return
tf
.
where
(
vocab_range
!=
forced_eos_token_id
,
LARGE_NEGATIVE
,
logits
)
else
:
return
logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment