Unverified Commit a04d4bf2 authored by Yih-Dar, committed by GitHub

Fix flax gpt2 hidden states (#13109)



* Fix inconsistency of the last element in hidden_states between PyTorch/Flax GPT2(Neo) (#13102)

* Fix missing elements in outputs tuple

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Fix local variable 'all_hidden_states' referenced before assignment

* Fix by returning tuple containing None values

* Fix quality
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent d8fb278a
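
The inconsistency this commit fixes (#13102) can be illustrated with a short check. This is a hedged sketch, not part of the commit: after the fix, the last element of the Flax model's `hidden_states` is the final layer-normalized tensor, so it matches `last_hidden_state`, as it already did in PyTorch.

import numpy as np
from transformers import GPT2Tokenizer, FlaxGPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = FlaxGPT2Model.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="np")
outputs = model(**inputs, output_hidden_states=True)

# Before this fix, hidden_states[-1] was the last block's output *before*
# the final layer norm, so this assertion failed while the equivalent
# PyTorch check passed; after the fix both frameworks agree.
assert np.allclose(outputs.hidden_states[-1], outputs.last_hidden_state)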
src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -24,7 +24,7 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt2 import GPT2Config
@@ -458,20 +458,10 @@ class FlaxGPT2BlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        outputs = (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-        )
+        # this contains possible `None` values - `FlaxGPT2Module` will filter them out
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+
+        return outputs
 
 
 class FlaxGPT2Module(nn.Module):
@@ -527,13 +517,19 @@ class FlaxGPT2Module(nn.Module):
         hidden_states = outputs[0]
         hidden_states = self.ln_f(hidden_states)
 
+        if output_hidden_states:
+            all_hidden_states = outputs[1] + (hidden_states,)
+            outputs = (hidden_states, all_hidden_states) + outputs[2:]
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         if not return_dict:
-            return (hidden_states,) + outputs[1:]
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=hidden_states,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            hidden_states=outputs[1],
+            attentions=outputs[-1],
         )
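The handoff introduced above deserves a note: the block collection now always returns a fixed-position tuple that may contain `None` placeholders, and the outer module filters them out only on the `return_dict=False` path, after appending the final layer-normalized state. A standalone sketch of the pattern, with simplified names that are not from the diff:

def blocks(h, output_hidden_states=False, output_attentions=False):
    all_hidden = (h,) if output_hidden_states else None
    all_attn = () if output_attentions else None
    # positions are stable: (hidden_states, all_hidden_states, all_attentions)
    return (h, all_hidden, all_attn)

outs = blocks([1.0, 2.0])
assert outs == ([1.0, 2.0], None, None)
# the tuple return path drops the None placeholders, mirroring
# `tuple(v for v in outputs if v is not None)` above
assert tuple(v for v in outs if v is not None) == ([1.0, 2.0],)

Keeping the positions fixed is what lets the outer module index `outputs[1]` and `outputs[-1]` safely no matter which optional outputs were requested.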
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -25,7 +25,7 @@ from flax.linen.attention import dot_product_attention_weights
 from jax import lax
 
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
 from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
 from ...utils import logging
 from .configuration_gpt_neo import GPTNeoConfig
@@ -488,20 +488,10 @@ class FlaxGPTNeoBlockCollection(nn.Module):
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        outputs = (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in outputs if v is not None)
-
-        return FlaxBaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-        )
+        # this contains possible `None` values - `FlaxGPTNeoModule` will filter them out
+        outputs = (hidden_states, all_hidden_states, all_attentions)
+
+        return outputs
 
 
 class FlaxGPTNeoModule(nn.Module):
@@ -557,13 +547,22 @@ class FlaxGPTNeoModule(nn.Module):
         hidden_states = outputs[0]
         hidden_states = self.ln_f(hidden_states)
 
+        hidden_states = outputs[0]
+        hidden_states = self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = outputs[1] + (hidden_states,)
+            outputs = (hidden_states, all_hidden_states) + outputs[2:]
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
         if not return_dict:
-            return (hidden_states,) + outputs[1:]
+            return tuple(v for v in outputs if v is not None)
 
         return FlaxBaseModelOutput(
             last_hidden_state=hidden_states,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            hidden_states=outputs[1],
+            attentions=outputs[-1],
         )
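The same agreement can be exercised end to end for GPT-Neo. A hedged sketch for manual verification; the checkpoint name and tolerance are illustrative choices, not taken from the commit:

import numpy as np
from transformers import GPT2Tokenizer, GPTNeoModel, FlaxGPTNeoModel

name = "EleutherAI/gpt-neo-125M"
tokenizer = GPT2Tokenizer.from_pretrained(name)
pt_model = GPTNeoModel.from_pretrained(name)
fx_model = FlaxGPTNeoModel.from_pretrained(name, from_pt=True)

text = "Hello, world"
pt_out = pt_model(**tokenizer(text, return_tensors="pt"), output_hidden_states=True)
fx_out = fx_model(**tokenizer(text, return_tensors="np"), output_hidden_states=True)

# with this fix, every element of hidden_states should line up between the
# frameworks, including the final (post-ln_f) entry
for pt_h, fx_h in zip(pt_out.hidden_states, fx_out.hidden_states):
    np.testing.assert_allclose(pt_h.detach().numpy(), np.asarray(fx_h), atol=1e-3)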