Unverified commit ea05d671, authored by Yih-Dar, committed by GitHub
Browse files

Fix some Flax models' `hidden_states` (#16167)



* fix the last element in `hidden_states`

* fix missing elements in outputs for FlaxWav2Vec2EncoderLayerStableLayerNormCollection
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 88f7c564
...@@ -711,12 +711,19 @@ class FlaxBlenderbotEncoder(nn.Module): ...@@ -711,12 +711,19 @@ class FlaxBlenderbotEncoder(nn.Module):
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict: if not return_dict:
return (last_hidden_states,) + outputs[1:] outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
...@@ -782,12 +789,19 @@ class FlaxBlenderbotDecoder(nn.Module): ...@@ -782,12 +789,19 @@ class FlaxBlenderbotDecoder(nn.Module):
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict: if not return_dict:
return (last_hidden_states,) + outputs[1:] outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions, cross_attentions=outputs.cross_attentions,
) )
......
...@@ -768,12 +768,19 @@ class FlaxMBartEncoder(nn.Module): ...@@ -768,12 +768,19 @@ class FlaxMBartEncoder(nn.Module):
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict: if not return_dict:
return (last_hidden_states,) + outputs[1:] outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
...@@ -845,12 +852,19 @@ class FlaxMBartDecoder(nn.Module): ...@@ -845,12 +852,19 @@ class FlaxMBartDecoder(nn.Module):
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_states,)
if not return_dict: if not return_dict:
return (last_hidden_states,) + outputs[1:] outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions, cross_attentions=outputs.cross_attentions,
) )
......
...@@ -727,12 +727,19 @@ class FlaxPegasusEncoder(nn.Module): ...@@ -727,12 +727,19 @@ class FlaxPegasusEncoder(nn.Module):
last_hidden_state = outputs[0] last_hidden_state = outputs[0]
last_hidden_state = self.layer_norm(last_hidden_state) last_hidden_state = self.layer_norm(last_hidden_state)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
if not return_dict: if not return_dict:
return (last_hidden_state,) + outputs[1:] outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutput(
last_hidden_state=last_hidden_state, last_hidden_state=last_hidden_state,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
...@@ -796,12 +803,19 @@ class FlaxPegasusDecoder(nn.Module): ...@@ -796,12 +803,19 @@ class FlaxPegasusDecoder(nn.Module):
last_hidden_state = outputs[0] last_hidden_state = outputs[0]
last_hidden_state = self.layer_norm(last_hidden_state) last_hidden_state = self.layer_norm(last_hidden_state)
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
if not return_dict: if not return_dict:
return (last_hidden_state,) + outputs[1:] outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_state, last_hidden_state=last_hidden_state,
hidden_states=outputs.hidden_states, hidden_states=hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions, cross_attentions=outputs.cross_attentions,
) )
......
...@@ -631,7 +631,7 @@ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module): ...@@ -631,7 +631,7 @@ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict: if not return_dict:
return tuple(v for v in outputs if v is not None) return tuple(v for v in outputs if v is not None)
...@@ -680,13 +680,20 @@ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module): ...@@ -680,13 +680,20 @@ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
return_dict=return_dict, return_dict=return_dict,
) )
hidden_states = self.layer_norm(outputs[0]) last_hidden_state = self.layer_norm(outputs[0])
# update the last element in `hidden_states` after applying `layernorm` above
hidden_states = None
if output_hidden_states:
hidden_states = outputs[1]
hidden_states = hidden_states[:-1] + (last_hidden_state,)
if not return_dict: if not return_dict:
return (hidden_states,) + outputs[1:] outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput( return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.