Unverified Commit a6885db9 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[Flax tests] fix test_model_outputs_equivalence (#15571)

* fix test_model_outputs_equivalence

* fix tuple outputs for blenderbot
parent fcb4f11c
......@@ -719,7 +719,7 @@ class FlaxBlenderbotEncoder(nn.Module):
last_hidden_states = self.layer_norm(last_hidden_states)
if not return_dict:
return outputs
return (last_hidden_states,) + outputs[1:]
return FlaxBaseModelOutput(
last_hidden_state=last_hidden_states,
......@@ -797,7 +797,7 @@ class FlaxBlenderbotDecoder(nn.Module):
last_hidden_states = self.layer_norm(last_hidden_states)
if not return_dict:
return outputs
return (last_hidden_states,) + outputs[1:]
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_states,
......
......@@ -134,10 +134,6 @@ class FlaxModelTesterMixin:
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
......@@ -149,11 +145,9 @@ class FlaxModelTesterMixin:
elif tuple_object is None:
return
else:
self.assert_almost_equals(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
)
self.assert_almost_equals(jnp.nan_to_num(tuple_object), jnp.nan_to_num(dict_object), 1e-5)
recursive_check(tuple_output, dict_output)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment