Unverified Commit a6885db9 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax tests] fix test_model_outputs_equivalence (#15571)

* fix test_model_outputs_equivalence

* fix tuple outputs for blenderbot
parent fcb4f11c
...@@ -719,7 +719,7 @@ class FlaxBlenderbotEncoder(nn.Module): ...@@ -719,7 +719,7 @@ class FlaxBlenderbotEncoder(nn.Module):
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
if not return_dict: if not return_dict:
return outputs return (last_hidden_states,) + outputs[1:]
return FlaxBaseModelOutput( return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
...@@ -797,7 +797,7 @@ class FlaxBlenderbotDecoder(nn.Module): ...@@ -797,7 +797,7 @@ class FlaxBlenderbotDecoder(nn.Module):
last_hidden_states = self.layer_norm(last_hidden_states) last_hidden_states = self.layer_norm(last_hidden_states)
if not return_dict: if not return_dict:
return outputs return (last_hidden_states,) + outputs[1:]
return FlaxBaseModelOutputWithPastAndCrossAttentions( return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states, last_hidden_state=last_hidden_states,
......
...@@ -134,10 +134,6 @@ class FlaxModelTesterMixin: ...@@ -134,10 +134,6 @@ class FlaxModelTesterMixin:
def test_model_outputs_equivalence(self): def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
...@@ -149,11 +145,9 @@ class FlaxModelTesterMixin: ...@@ -149,11 +145,9 @@ class FlaxModelTesterMixin:
elif tuple_object is None: elif tuple_object is None:
return return
else: else:
self.assert_almost_equals( self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
)
recursive_check(tuple_output, dict_output) recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment