"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "98717cb34110c35f6c6b65b8d5a4b9932cf3dd98"
Unverified Commit a6752a7d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `serving_output` for TF composite models (encoder-decoder like models) (#22743)



* fix

* style

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 410b61ad
...@@ -633,14 +633,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -633,14 +633,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
) )
def serving_output(self, output): def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = (
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None )
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
enc_hs = (
tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
)
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
cross_attns = ( cross_attns = (
tf.convert_to_tensor(output.cross_attentions) tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions and output.cross_attentions is not None if self.config.decoder.output_attentions and output.cross_attentions is not None
else None else None
) )
......
...@@ -662,14 +662,18 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos ...@@ -662,14 +662,18 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
) )
def serving_output(self, output): def serving_output(self, output):
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_hs = (
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None )
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
enc_hs = (
tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
)
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
cross_attns = ( cross_attns = (
tf.convert_to_tensor(output.cross_attentions) tf.convert_to_tensor(output.cross_attentions)
if self.config.output_attentions and output.cross_attentions is not None if self.config.decoder.output_attentions and output.cross_attentions is not None
else None else None
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment