Unverified Commit 25848086 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

update serving_output for some TF models (#15568)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 315e6740
...@@ -2338,6 +2338,7 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2338,6 +2338,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None
...@@ -2347,6 +2348,7 @@ class TFLEDModel(TFLEDPreTrainedModel): ...@@ -2347,6 +2348,7 @@ class TFLEDModel(TFLEDPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
...@@ -2494,7 +2496,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2494,7 +2496,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
past_key_values=outputs.past_key_values, # index 1 of d outputs past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
cross_attentions=outputs.cross_attentions, cross_attentions=outputs.cross_attentions, # index 4 of d outputs
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out
...@@ -2505,6 +2507,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2505,6 +2507,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None
...@@ -2514,6 +2517,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2514,6 +2517,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
......
...@@ -1280,6 +1280,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1280,6 +1280,7 @@ class TFT5Model(TFT5PreTrainedModel):
pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1289,6 +1290,7 @@ class TFT5Model(TFT5PreTrainedModel): ...@@ -1289,6 +1290,7 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
cross_attentions=cross_attns,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
) )
...@@ -1525,6 +1527,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1525,6 +1527,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
...@@ -1533,6 +1536,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1533,6 +1536,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
past_key_values=pkv, past_key_values=pkv,
decoder_hidden_states=dec_hs, decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns, decoder_attentions=dec_attns,
cross_attentions=cross_attns,
encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_last_hidden_state=output.encoder_last_hidden_state,
encoder_hidden_states=enc_hs, encoder_hidden_states=enc_hs,
encoder_attentions=enc_attns, encoder_attentions=enc_attns,
......
...@@ -1514,7 +1514,12 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel): ...@@ -1514,7 +1514,12 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) return TFWav2Vec2BaseModelOutput(
last_hidden_state=output.last_hidden_state,
extract_features=output.extract_features,
hidden_states=hs,
attentions=attns,
)
@add_start_docstrings( @add_start_docstrings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment