chenpangpang / transformers / Commits / 4fd89e49

Unverified commit 4fd89e49, authored Jan 03, 2023 by Joao Gante, committed by GitHub on Jan 03, 2023.
Generate: delete unused TF `_reorder_cache` (#20964)
Parent: a3e8d3cb

Changes: 23
Showing 20 changed files (page 1 of 2) with 0 additions and 185 deletions (+0 -185):
    src/transformers/generation/tf_utils.py  (+0 -4)
    src/transformers/models/bart/modeling_tf_bart.py  (+0 -10)
    src/transformers/models/bert/modeling_tf_bert.py  (+0 -7)
    src/transformers/models/blenderbot/modeling_tf_blenderbot.py  (+0 -11)
    src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py  (+0 -11)
    src/transformers/models/camembert/modeling_tf_camembert.py  (+0 -8)
    src/transformers/models/ctrl/modeling_tf_ctrl.py  (+0 -6)
    src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py  (+0 -4)
    src/transformers/models/led/modeling_tf_led.py  (+0 -10)
    src/transformers/models/marian/modeling_tf_marian.py  (+0 -11)
    src/transformers/models/mbart/modeling_tf_mbart.py  (+0 -11)
    src/transformers/models/pegasus/modeling_tf_pegasus.py  (+0 -11)
    src/transformers/models/rag/modeling_tf_rag.py  (+0 -18)
    src/transformers/models/rembert/modeling_tf_rembert.py  (+0 -8)
    src/transformers/models/roberta/modeling_tf_roberta.py  (+0 -8)
    src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py  (+0 -8)
    src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py  (+0 -7)
    src/transformers/models/t5/modeling_tf_t5.py  (+0 -24)
    src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py  (+0 -4)
    src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py  (+0 -4)
src/transformers/generation/tf_utils.py

@@ -449,10 +449,6 @@ class TFGenerationMixin:
     supports_xla_generation = True
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
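For context on the deleted base-class hook: during beam search, each layer's cached tensor has to be reshuffled so that slot i of the cache holds the past of whichever beam the i-th continuation came from, and the removed TFGenerationMixin version did that with one tf.gather per layer. A minimal standalone sketch of the same pattern (toy shapes and variable names are illustrative only, not library code):

    import tensorflow as tf

    # Toy cache: one tensor per layer. The leading dimension of 2 stands in for
    # stacked key/value tensors; the beam dimension sits on axis 1, matching the
    # axis=1 gather in the deleted TFGenerationMixin._reorder_cache above.
    num_layers, num_beams, seq_len, hidden = 2, 3, 4, 8
    past = tuple(tf.random.normal((2, num_beams, seq_len, hidden)) for _ in range(num_layers))

    # After a beam-search step, beams 0 and 1 might both continue from old beam 1,
    # while beam 2 continues from old beam 0.
    beam_idx = tf.constant([1, 1, 0])

    reordered = tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
    assert all(r.shape == p.shape for r, p in zip(reordered, past))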
src/transformers/models/bart/modeling_tf_bart.py

@@ -1475,16 +1475,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
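The seq2seq variants (BART and the models that copy from it below) gather only the first two entries of each layer's cache: per the in-code comment, the remaining entries are cached cross-attention states that are identical for every beam. A small sketch of that slicing, with a hypothetical four-entry key/value layout assumed for illustration:

    import tensorflow as tf

    beams, heads, tgt_len, src_len, head_dim = 3, 2, 5, 7, 4
    self_k = tf.random.normal((beams, heads, tgt_len, head_dim))
    self_v = tf.random.normal((beams, heads, tgt_len, head_dim))
    cross_k = tf.random.normal((beams, heads, src_len, head_dim))
    cross_v = tf.random.normal((beams, heads, src_len, head_dim))
    layer_past = (self_k, self_v, cross_k, cross_v)

    beam_idx = tf.constant([2, 0, 0])

    # Only the self-attention states depend on which beam produced them; the
    # cross-attention states are computed from the encoder output once and are
    # the same for every beam, so they are passed through untouched.
    reordered_layer = (
        tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:]
    )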
src/transformers/models/bert/modeling_tf_bert.py

@@ -1508,13 +1508,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",
src/transformers/models/blenderbot/modeling_tf_blenderbot.py

@@ -1473,14 +1473,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

@@ -1453,14 +1453,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
src/transformers/models/camembert/modeling_tf_camembert.py

@@ -1726,11 +1726,3 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
         return TFCausalLMOutputWithCrossAttentions(
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
-
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
src/transformers/models/ctrl/modeling_tf_ctrl.py

@@ -722,12 +722,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
 
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
-        return tuple(
-            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
-        )
-
 
 @add_start_docstrings(
     """
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -720,7 +720,3 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
             " model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
src/transformers/models/led/modeling_tf_led.py

@@ -2538,16 +2538,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
src/transformers/models/marian/modeling_tf_marian.py

@@ -1494,17 +1494,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
src/transformers/models/mbart/modeling_tf_mbart.py

@@ -1490,14 +1490,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
src/transformers/models/pegasus/modeling_tf_pegasus.py

@@ -1503,14 +1503,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
src/transformers/models/rag/modeling_tf_rag.py

@@ -799,24 +799,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     def question_encoder(self):
         return self.rag.question_encoder
 
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-
-        def _reorder_stacked(hidden_states, new_order):
-            n_docs = hidden_states.shape[0] // new_order.shape[0]
-            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf.gather(hidden_states, new_order, axis=0)
-            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
-            return result
-
-        reordered_past = ()
-        for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-
-        return reordered_past
-
     @staticmethod
     def _gather_beams(nested, beam_indices, batch_axis=0):
         """
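The RAG variant differs because each cached tensor is stacked across the retrieved documents, so its leading dimension is batch * n_docs rather than batch; the deleted _reorder_stacked helper regroups by document before gathering. A toy run of the same reshape/gather/reshape logic (shapes are illustrative):

    import tensorflow as tf

    batch, n_docs, seq_len, hidden = 2, 3, 4, 8
    # Leading dimension is batch * n_docs, as in the deleted RAG helper.
    hidden_states = tf.random.normal((batch * n_docs, seq_len, hidden))
    new_order = tf.constant([1, 0])  # reorder at the batch (beam) level

    n_docs_inferred = hidden_states.shape[0] // new_order.shape[0]
    grouped = tf.reshape(hidden_states, (-1, n_docs_inferred, *hidden_states.shape[1:]))
    gathered = tf.gather(grouped, new_order, axis=0)
    result = tf.reshape(gathered, (-1, *gathered.shape[2:]))
    assert result.shape == hidden_states.shape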
src/transformers/models/rembert/modeling_tf_rembert.py

@@ -1244,14 +1244,6 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 @add_start_docstrings(
     """
src/transformers/models/roberta/modeling_tf_roberta.py

@@ -1286,14 +1286,6 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 class TFRobertaClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""
src/transformers/models/roberta_prelayernorm/modeling_tf_roberta_prelayernorm.py

@@ -1301,14 +1301,6 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 
 # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
 class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py

@@ -1501,10 +1501,3 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
src/transformers/models/t5/modeling_tf_t5.py

@@ -1528,30 +1528,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)
 
-    def _reorder_cache(self, past, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
-
-        reordered_decoder_past = ()
-        for layer_past_states in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    tf.gather(layer_past_state, beam_idx, axis=0),
-                )
-
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-
 
 @add_start_docstrings(
     "The bare T5 Model transformer outputting encoder's raw hidden-stateswithout any specific head on top.",
src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py

@@ -1039,10 +1039,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
         return inputs
 
-    @staticmethod
-    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
-        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
-
 
 @add_start_docstrings(
     """
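The Transformer-XL method gathers along axis=1, which suggests its memory tensors keep the batch/beam dimension in the second position; the sketch below simply replays that pattern on toy tensors under that assumption (names and shapes are illustrative, not library code):

    from typing import List

    import tensorflow as tf

    mem_len, num_beams, d_model = 5, 3, 8
    mems: List[tf.Tensor] = [tf.random.normal((mem_len, num_beams, d_model)) for _ in range(2)]
    beam_idx = tf.constant([0, 2, 2])

    # Gather along axis 1 because the beam dimension is assumed to sit second
    # in each memory tensor, mirroring the deleted method above.
    reordered_mems = [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
    assert all(r.shape == m.shape for r, m in zip(reordered_mems, mems))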
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py

@@ -756,7 +756,3 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
             "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)