Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4fd89e49
Unverified
Commit
4fd89e49
authored
Jan 03, 2023
by
Joao Gante
Committed by
GitHub
Jan 03, 2023
Browse files
Generate: delete unused TF `_reorder_cache` (#20964)
parent
a3e8d3cb
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
22 deletions
+0
-22
src/transformers/models/whisper/modeling_tf_whisper.py
src/transformers/models/whisper/modeling_tf_whisper.py
+0
-8
src/transformers/models/xglm/modeling_tf_xglm.py
src/transformers/models/xglm/modeling_tf_xglm.py
+0
-7
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
...ame}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+0
-7
No files found.
src/transformers/models/whisper/modeling_tf_whisper.py
View file @
4fd89e49
...
@@ -1386,11 +1386,3 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
...
@@ -1386,11 +1386,3 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
"decoder_attention_mask"
:
decoder_attention_mask
,
"decoder_attention_mask"
:
decoder_attention_mask
,
"decoder_position_ids"
:
decoder_position_ids
,
"decoder_position_ids"
:
decoder_position_ids
,
}
}
#
@
staticmethod
def
_reorder_cache
(
past
,
beam_idx
):
reordered_past
=
()
for
layer_past
in
past
:
reordered_past
+=
(
tuple
(
tf
.
gather
(
past_state
,
beam_idx
)
for
past_state
in
layer_past
),)
return
reordered_past
src/transformers/models/xglm/modeling_tf_xglm.py
View file @
4fd89e49
...
@@ -992,10 +992,3 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -992,10 +992,3 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
attentions
=
attns
,
attentions
=
attns
,
cross_attentions
=
cross_attns
,
cross_attentions
=
cross_attns
,
)
)
@staticmethod
def _reorder_cache(past, beam_idx):
    """Rearrange each layer's cached states so they track the surviving beams.

    For every layer cache in ``past``, gather the rows (axis 0) of each
    cached tensor that correspond to the beam indices in ``beam_idx`` and
    return the result with the same nested-tuple structure as the input.
    """
    # Nested tuple comprehension: outer over layers, inner over the
    # cached tensors of each layer — equivalent to accumulating with +=.
    return tuple(
        tuple(tf.gather(state, beam_idx, axis=0) for state in layer_cache)
        for layer_cache in past
    )
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
View file @
4fd89e49
...
@@ -3028,13 +3028,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
...
@@ -3028,13 +3028,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
"use_cache"
:
use_cache
,
# change this to avoid caching (presumably for debugging)
"use_cache"
:
use_cache
,
# change this to avoid caching (presumably for debugging)
}
}
@staticmethod
def _reorder_cache(past, beam_idx):
    """Reorder the decoder cache to match the beam order chosen by beam search.

    Returns a tuple with the same nesting as ``past`` in which every cached
    tensor has been re-indexed along its first axis by ``beam_idx``.
    """
    reordered = []
    for layer_cache in past:
        # Select the beam rows of every cached tensor in this layer.
        gathered = tuple(tf.gather(state, beam_idx, axis=0) for state in layer_cache)
        reordered.append(gathered)
    return tuple(reordered)
def
hf_compute_loss
(
self
,
labels
,
logits
):
def
hf_compute_loss
(
self
,
labels
,
logits
):
"""CrossEntropyLoss that ignores pad tokens"""
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
loss_fn
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment