Unverified Commit ea05d671 authored by Yih-Dar, committed by GitHub

Fix some Flax models' `hidden_states` (#16167)



* fix the last element in `hidden_states`

* fix missing elements in outputs for FlaxWav2Vec2EncoderLayerStableLayerNormCollection
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 88f7c564
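Context for the hunks below: each patched module applies a final layer norm to the output of its layer stack, but the `hidden_states` tuple collected inside the stack kept the pre-layernorm tensor as its last element, so `hidden_states[-1]` disagreed with `last_hidden_state`. A minimal sketch of the pattern and the fix, using illustrative names rather than the real module internals:

# Minimal sketch, assuming a simplified layer stack (all names illustrative).
import jax.numpy as jnp

def encode(x, layers, layer_norm, output_hidden_states=False):
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            all_hidden_states += (x,)
        x = layer(x)
    if output_hidden_states:
        all_hidden_states += (x,)  # last element is the *pre*-layernorm tensor

    last_hidden_state = layer_norm(x)

    # The fix: rewrite the final element so the tuple matches the
    # layer-normed output that is actually returned.
    if output_hidden_states:
        all_hidden_states = all_hidden_states[:-1] + (last_hidden_state,)
    return last_hidden_state, all_hidden_states

# Toy usage: an additive "layer" and a scaling "layer norm" make the
# mismatch easy to see without real model weights.
x = jnp.ones((2, 4))
last, hs = encode(x, [lambda t: t + 1.0], lambda t: t * 0.5, output_hidden_states=True)
assert jnp.allclose(hs[-1], last)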
@@ -711,12 +711,19 @@ class FlaxBlenderbotEncoder(nn.Module):
         last_hidden_states = outputs[0]
         last_hidden_states = self.layer_norm(last_hidden_states)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_states,)
+
         if not return_dict:
-            return (last_hidden_states,) + outputs[1:]
+            outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=last_hidden_states,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
         )
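Why the non-`return_dict` branch above needs the conditional slice: `outputs[1]` holds the hidden-states tuple only when `output_hidden_states=True`; otherwise the attentions start at index 1, `hidden_states` stays `None`, and the trailing `None`-filter preserves the old positional layout. A toy check with plain Python stand-ins for the tensors (the same bookkeeping recurs in the decoder and the MBart/Pegasus hunks below):

# output_hidden_states=True: outputs = (pre_ln, hidden_states_tuple, attentions)
last = "post_ln"
outputs = ("pre_ln", ("h0", "h1", "pre_ln"), ("a0", "a1"))
hidden_states = outputs[1][:-1] + (last,)
rebuilt = (last, hidden_states) + outputs[2:]
assert rebuilt == ("post_ln", ("h0", "h1", "post_ln"), ("a0", "a1"))

# output_hidden_states=False: outputs = (pre_ln, attentions); the None is
# filtered out, so callers see the same tuple shape as before the patch.
outputs = ("pre_ln", ("a0", "a1"))
rebuilt = tuple(v for v in (last, None) + outputs[1:] if v is not None)
assert rebuilt == ("post_ln", ("a0", "a1"))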
@@ -782,12 +789,19 @@ class FlaxBlenderbotDecoder(nn.Module):
         last_hidden_states = outputs[0]
         last_hidden_states = self.layer_norm(last_hidden_states)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_states,)
+
         if not return_dict:
-            return (last_hidden_states,) + outputs[1:]
+            outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=last_hidden_states,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
@@ -768,12 +768,19 @@ class FlaxMBartEncoder(nn.Module):
         last_hidden_states = outputs[0]
         last_hidden_states = self.layer_norm(last_hidden_states)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_states,)
+
         if not return_dict:
-            return (last_hidden_states,) + outputs[1:]
+            outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=last_hidden_states,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
         )
@@ -845,12 +852,19 @@ class FlaxMBartDecoder(nn.Module):
         last_hidden_states = outputs[0]
         last_hidden_states = self.layer_norm(last_hidden_states)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_states,)
+
         if not return_dict:
-            return (last_hidden_states,) + outputs[1:]
+            outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=last_hidden_states,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
@@ -727,12 +727,19 @@ class FlaxPegasusEncoder(nn.Module):
         last_hidden_state = outputs[0]
         last_hidden_state = self.layer_norm(last_hidden_state)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
         if not return_dict:
-            return (last_hidden_state,) + outputs[1:]
+            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=last_hidden_state,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
         )
@@ -796,12 +803,19 @@ class FlaxPegasusDecoder(nn.Module):
         last_hidden_state = outputs[0]
         last_hidden_state = self.layer_norm(last_hidden_state)
 
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_state,)
+
         if not return_dict:
-            return (last_hidden_state,) + outputs[1:]
+            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=last_hidden_state,
-            hidden_states=outputs.hidden_states,
+            hidden_states=hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
@@ -631,7 +631,7 @@ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        outputs = (hidden_states,)
+        outputs = (hidden_states, all_hidden_states, all_attentions)
 
         if not return_dict:
             return tuple(v for v in outputs if v is not None)
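This collection had a second bug: it accumulated `all_hidden_states` and `all_attentions` but packed only `(hidden_states,)` into `outputs`, so the filtered tuple return above (and the model output built from it) silently dropped both even when the caller requested them. A toy before/after with stand-in values:

# Stand-in values for the three accumulators.
hidden, all_hidden, all_attn = "h_last", ("h0", "h_last"), ("a0",)

before = tuple(v for v in (hidden,) if v is not None)
after = tuple(v for v in (hidden, all_hidden, all_attn) if v is not None)

assert before == ("h_last",)  # hidden states and attentions were lost
assert after == ("h_last", ("h0", "h_last"), ("a0",))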
@@ -680,13 +680,20 @@ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
             return_dict=return_dict,
         )
 
-        hidden_states = self.layer_norm(outputs[0])
+        last_hidden_state = self.layer_norm(outputs[0])
+
+        # update the last element in `hidden_states` after applying `layernorm` above
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = outputs[1]
+            hidden_states = hidden_states[:-1] + (last_hidden_state,)
 
         if not return_dict:
-            return (hidden_states,) + outputs[1:]
+            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
         )
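A hedged end-to-end check of the invariant these patches restore. The checkpoint name is illustrative (stable-layer-norm variants, i.e. `do_stable_layer_norm=True` such as the *-large-lv60 checkpoints, are the ones that reach the encoder patched above), and `from_pt=True` assumes only PyTorch weights are published for it:

import jax.numpy as jnp
from transformers import FlaxWav2Vec2Model

# Illustrative checkpoint; any do_stable_layer_norm=True config works here.
model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
input_values = jnp.ones((1, 16000), dtype=jnp.float32)  # one second of dummy audio
out = model(input_values, output_hidden_states=True)

# With the fix, the last collected hidden state is the layer-normed one.
assert jnp.allclose(out.hidden_states[-1], out.last_hidden_state)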