Unverified Commit a6752a7d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `serving_output` for TF composite models (encoder-decoder like models) (#22743)



* fix

* style

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 410b61ad
......@@ -633,14 +633,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
)
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
dec_hs = (
tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
)
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
enc_hs = (
tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
)
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions and output.cross_attentions is not None
if self.config.decoder.output_attentions and output.cross_attentions is not None
else None
)
......
......@@ -662,14 +662,18 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
)
def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
dec_hs = (
tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
)
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
enc_hs = (
tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
)
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
cross_attns = (
tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions and output.cross_attentions is not None
if self.config.decoder.output_attentions and output.cross_attentions is not None
else None
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment