"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e7d52a10d721f4475c810d403b1e71689d4b94b9"
Unverified Commit aebca696 authored by Yih-Dar, committed by GitHub

Fix missing output_attentions in PT/Flax equivalence test (#16271)



* fix - set output_attentions to True

* Update tests/test_modeling_flax_common.py

* update for has_attentions

* overwrite check_outputs in FlaxBigBirdModelTest
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 45abb37a
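In short, the common PT/Flax equivalence tests previously only switched on `output_hidden_states`, so attention weights were never part of the comparison. Below is a minimal, self-contained sketch of the kind of field-by-field comparison the tests perform once `output_attentions` is enabled; the helper name and the dict-based outputs are illustrative only, not the test suite's actual API.

```python
import numpy as np

def compare_outputs(fx_outputs, pt_outputs, tol=1e-5):
    """Compare matching output fields from a Flax run and a PyTorch run, field by field."""
    for name, pt_value in pt_outputs.items():
        fx_value = fx_outputs[name]
        # Attention and hidden-state outputs are tuples with one array per layer.
        pt_items = pt_value if isinstance(pt_value, tuple) else (pt_value,)
        fx_items = fx_value if isinstance(fx_value, tuple) else (fx_value,)
        for pt_item, fx_item in zip(pt_items, fx_items):
            max_diff = np.amax(np.abs(np.asarray(fx_item) - np.asarray(pt_item)))
            assert max_diff <= tol, f"{name}: max diff {max_diff} > {tol}"

# With output_attentions enabled, "attentions" shows up in both outputs and is checked too.
fx = {"last_hidden_state": np.zeros((1, 4, 8)), "attentions": (np.zeros((1, 2, 4, 4)),)}
pt = {"last_hidden_state": np.zeros((1, 4, 8)), "attentions": (np.zeros((1, 2, 4, 4)),)}
compare_outputs(fx, pt)
```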
@@ -190,3 +190,12 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
                 for jitted_output, output in zip(jitted_outputs, outputs):
                     self.assertEqual(jitted_output.shape, output.shape)
+
+    # overwrite from common in order to skip the check on `attentions`
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
+        # an effort was done to return `attention_probs` (yet to be verified).
+        if type(names) == str and names.startswith("attentions"):
+            return
+        else:
+            super().check_outputs(fx_outputs, pt_outputs, model_class, names)
@@ -120,6 +120,7 @@ class FlaxModelTesterMixin:
     test_mismatched_shapes = True
     is_encoder_decoder = False
     test_head_masking = False
+    has_attentions = True
 
     def _prepare_for_class(self, inputs_dict, model_class):
         inputs_dict = copy.deepcopy(inputs_dict)
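The new `has_attentions` flag defaults to `True` on the mixin; a tester for a model without attention layers (the removed comment further down mentions pure convolutional models) would presumably just override it. A hypothetical sketch, assuming the mixin is importable from `tests/test_modeling_flax_common.py`:

```python
import unittest

# Assumed import path; in the repo the mixin lives in tests/test_modeling_flax_common.py.
from test_modeling_flax_common import FlaxModelTesterMixin


class FlaxSomeConvOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
    # Hypothetical tester for an attention-free model: with the flag off, the
    # equivalence tests no longer set config.output_attentions and skip those checks.
    has_attentions = False
```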
@@ -168,6 +169,7 @@ class FlaxModelTesterMixin:
             dict_inputs = self._prepare_for_class(inputs_dict, model_class)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
+    # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
     def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
         """
         Args:
@@ -204,8 +206,7 @@ class FlaxModelTesterMixin:
             pt_outputs[pt_nans] = 0
             fx_outputs[pt_nans] = 0
 
-            max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
-            self.assertLessEqual(max_diff, 1e-5)
+            self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
         else:
             raise ValueError(
                 f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
@@ -222,6 +223,7 @@ class FlaxModelTesterMixin:
 
             # Output all for aggressive testing
             config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
 
             # prepare inputs
             prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
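With `config.output_attentions` now set from `has_attentions`, both the PyTorch and the Flax model return per-layer attention weights, so they become part of the equivalence check. A hedged usage example of what enabling the flag produces on the Flax side (the `bert-base-uncased` checkpoint is only an example and must be downloadable):

```python
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("hello world", return_tensors="np")
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# One (batch, num_heads, seq_len, seq_len) array per layer; this is the extra
# output that the equivalence test now compares against the PyTorch model.
print(len(outputs.attentions), outputs.attentions[0].shape)
```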
@@ -274,7 +276,7 @@ class FlaxModelTesterMixin:
             # Output all for aggressive testing
             config.output_hidden_states = True
-            # Pure convolutional models have no attention
+            config.output_attentions = self.has_attentions
 
             # prepare inputs
             prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@@ -314,6 +316,7 @@ class FlaxModelTesterMixin:
 
             # send pytorch model to the correct device
             pt_model_loaded.to(torch_device)
+            pt_model_loaded.eval()
 
             with torch.no_grad():
                 pt_outputs_loaded = pt_model_loaded(**pt_inputs)
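The added `pt_model_loaded.eval()` makes sure the reloaded PyTorch model runs in evaluation mode; in training mode, dropout is active and the outputs are not deterministic, so a comparison against the Flax outputs could fail spuriously. A small self-contained illustration of the difference, using generic `torch.nn` modules rather than the test's actual model:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 4)

with torch.no_grad():
    train_a, train_b = model(x), model(x)   # training mode: dropout is applied

model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)     # eval mode: dropout disabled

print(torch.equal(train_a, train_b))  # usually False: two different dropout masks
print(torch.equal(eval_a, eval_b))    # True: deterministic output
```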