Commit 7d65697d (unverified)
Authored Aug 07, 2023 by Joao Gante, committed by GitHub on Aug 07, 2023

Generate: remove Marian hack (#25294)

Remove Marian hack

Parent: 14510938
Showing 4 changed files, with 0 additions and 58 deletions:

  +0  -21   src/transformers/generation/tf_utils.py
  +0  -18   src/transformers/generation/utils.py
  +0   -4   src/transformers/models/marian/modeling_marian.py
  +0  -15   src/transformers/models/marian/modeling_tf_marian.py
src/transformers/generation/tf_utils.py

@@ -474,27 +474,6 @@ class TFGenerationMixin:
                 "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
             )
 
-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        vocab_size = getattr(self.config, "vocab_size", None)
-        if vocab_size is None and self.config.is_encoder_decoder:
-            decoder_config = getattr(self.config, "decoder", None)
-            if decoder_config is not None:
-                vocab_size = getattr(self.config.decoder, "vocab_size", None)
-
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
-        else:
-            return logits
-
     def compute_transition_scores(
         self, sequences: tf.Tensor, ...
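For reference, the base-class hook deleted above only ever did two things: at the first decoding step it masked every vocabulary entry except `forced_bos_token_id`, and at the final step it masked everything except `forced_eos_token_id`. A minimal standalone sketch of that masking, mirroring the removed code (the helper name `force_single_token` is ours, not the library's):

    import tensorflow as tf

    def force_single_token(logits: tf.Tensor, token_id: int) -> tf.Tensor:
        # Push every vocabulary position except `token_id` to a very large
        # negative value, so only `token_id` survives sampling / beam search.
        vocab_range = tf.range(tf.shape(logits)[-1])
        return tf.where(vocab_range != token_id, tf.constant(-1e8, dtype=logits.dtype), logits)

    # e.g. what the removed hook did at cur_len == 1:
    # logits = force_single_token(logits, forced_bos_token_id)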
src/transformers/generation/utils.py

@@ -578,12 +578,6 @@ class GenerationMixin:
         inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
         return inputs, input_name, model_kwargs
 
-    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        return logits
-
     def _maybe_initialize_input_ids_for_generation(
         self,
         inputs: Optional[torch.Tensor] = None, ...

@@ -3060,9 +3054,6 @@ class GenerationMixin:
                 continue  # don't waste resources running the code we don't need
 
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)

@@ -3388,9 +3379,6 @@
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)

@@ -3751,9 +3739,6 @@
             # select outputs of beams of current group only
             next_token_logits = outputs.logits[batch_group_indices, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * group_size, vocab_size)

@@ -4110,9 +4095,6 @@
                 continue  # don't waste resources running the code we don't need
 
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
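In each of the PyTorch search loops above, the deleted call sat between slicing the last-step logits and the `log_softmax`. A model no longer needs a private hook for this kind of adjustment: the public `logits_processor` argument of `generate` covers it. A minimal sketch, assuming a model whose config defines `pad_token_id` (the class name `SuppressPadTokenLogitsProcessor` is illustrative, not a library class):

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class SuppressPadTokenLogitsProcessor(LogitsProcessor):
        """Set the pad-token score to -inf at every step, like the removed Marian override."""

        def __init__(self, pad_token_id: int):
            self.pad_token_id = pad_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            scores[:, self.pad_token_id] = float("-inf")
            return scores

    # Usage sketch: pass the processor to `generate` instead of overriding a model method.
    # outputs = model.generate(
    #     **inputs,
    #     logits_processor=LogitsProcessorList(
    #         [SuppressPadTokenLogitsProcessor(model.config.pad_token_id)]
    #     ),
    # )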
src/transformers/models/marian/modeling_marian.py

@@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(self, logits, cur_len):
-        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
-        return logits
-
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = () ...
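The PyTorch override removed here did nothing beyond forbidding the pad token. If a user wants that guarantee explicitly, the stock `bad_words_ids` generation argument expresses it without any subclassing; a usage sketch (the `Helsinki-NLP/opus-mt-en-de` checkpoint is just an example, and this is an illustration, not a claim about what Marian requires after this commit):

    from transformers import MarianMTModel, MarianTokenizer

    model_name = "Helsinki-NLP/opus-mt-en-de"  # example checkpoint
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    inputs = tokenizer("Hello, how are you?", return_tensors="pt")
    # Forbid the pad token at every generation step, mimicking the deleted override.
    outputs = model.generate(**inputs, bad_words_ids=[[model.config.pad_token_id]])
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))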
src/transformers/models/marian/modeling_tf_marian.py

@@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """Never predict pad_token_id. Predict </s> when max_length is reached."""
-        vocab_range = tf.constant(range(self.config.vocab_size))
-        logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
-        else:
-            return logits
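The TF Marian override combined pad suppression with the forced-BOS/forced-EOS masking. Both forcing behaviors also exist as standard processors exported by the library; a small PyTorch sketch applying them to dummy tensors (all token ids and sizes below are placeholders):

    import torch
    from transformers import ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor

    vocab_size = 8                              # toy vocabulary
    scores = torch.zeros(1, vocab_size)         # dummy next-token scores
    decoder_ids = torch.tensor([[3]])           # one generated token so far (cur_len == 1)

    # Forces `bos_token_id` when exactly one token has been generated,
    # like the `cur_len == 1` branch of the removed override.
    bos = ForcedBOSTokenLogitsProcessor(bos_token_id=1)
    print(bos(decoder_ids, scores.clone()))

    # Forces `eos_token_id` at position `max_length - 1`,
    # like the `cur_len == max_length - 1` branch of the removed override.
    eos = ForcedEOSTokenLogitsProcessor(max_length=2, eos_token_id=2)
    print(eos(decoder_ids, scores.clone()))

In normal use these processors are added for you when `forced_bos_token_id` / `forced_eos_token_id` are set on the model or generation config, so calling them by hand as above is only for illustration.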