Unverified Commit f571dc20 authored by Yih-Dar, committed by GitHub

Update PT Flax equivalence tests in PT test file (#16280)



* update PT/Flax equivalence tests on PT side

* overwrite check_outputs in BigBirdModelTest
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 41bfc1e2
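For orientation before the diff: the heart of this change is a recursive helper that walks PT/Flax model outputs (which may be nested tuples of tensors) and compares the leaves with a NaN-tolerant max-difference check. Below is a minimal standalone sketch of that idea, using only `numpy` and a hypothetical `compare_outputs` name (not the test suite's actual API):

```python
import numpy as np

def compare_outputs(fx_out, pt_out, name, tol=1e-5):
    """Recursively compare possibly-nested (tuple/list) Flax vs. PyTorch outputs."""
    if isinstance(fx_out, (tuple, list)):
        assert len(fx_out) == len(pt_out), f"{name}: length mismatch"
        for idx, (f, p) in enumerate(zip(fx_out, pt_out)):
            # suffix the name with the index so failures are attributable, e.g. "hidden_states_3"
            compare_outputs(f, p, name=f"{name}_{idx}", tol=tol)
    else:
        f, p = np.array(fx_out, dtype=np.float32), np.array(pt_out, dtype=np.float32)
        # zero out positions that are NaN on either side so they don't dominate the max-diff
        nans = np.isnan(f) | np.isnan(p)
        f[nans], p[nans] = 0.0, 0.0
        diff = np.abs(f - p).max()
        assert diff <= tol, f"{name}: max diff {diff} exceeds {tol}"

# toy usage: two nearly identical nested outputs pass the check
compare_outputs((np.ones((2, 3)),), (np.ones((2, 3)) + 1e-7,), name="logits")
```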
@@ -596,6 +596,15 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
 
+    # overwrite from common in order to skip the check on `attentions`
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
+        # an effort was done to return `attention_probs` (yet to be verified).
+        if type(names) == str and names.startswith("attentions"):
+            return
+        else:
+            super().check_outputs(fx_outputs, pt_outputs, model_class, names)
+
 
 @require_torch
 @slow
...
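As an aside on the NaN handling that the base `check_outputs` (added to the common mixin in the next file) relies on: any position that is NaN in either framework's output is zeroed on both sides before the max absolute difference is taken, so a NaN in one framework cannot poison the comparison. A toy illustration with made-up values:

```python
import numpy as np

fx = np.array([1.0, np.nan, 3.0])
pt = np.array([1.0, 2.0, np.nan])

# mirror the masking order used in the test helper
fx_nans = np.isnan(fx)
pt_nans = np.isnan(pt)

pt[fx_nans] = 0
fx[fx_nans] = 0
pt[pt_nans] = 0
fx[pt_nans] = 0

print(np.abs(fx - pt).max())  # 0.0 -- only positions finite on both sides contribute
```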
@@ -1660,8 +1660,9 @@ class ModelTesterMixin:
             # transformers does not have TF version yet
             return
 
-        if self.has_attentions:
-            config.output_attentions = True
+        # Output all for aggressive testing
+        config.output_hidden_states = True
+        config.output_attentions = self.has_attentions
 
         for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
             if k in inputs_dict:
@@ -1728,28 +1729,75 @@ class ModelTesterMixin:
             diff = np.abs((a - b)).max()
             self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
 
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        """
+        Args:
+            model_class: The class of the model that is currently testing. For example, ..., etc.
+                Currently unused, but it could make debugging easier and faster.
+            names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
+                Currently unused, but in the future, we could use this information to make the error message clearer
+                by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
+        """
+        if type(fx_outputs) in [tuple, list]:
+            self.assertEqual(type(fx_outputs), type(pt_outputs))
+            self.assertEqual(len(fx_outputs), len(pt_outputs))
+            if type(names) == tuple:
+                for fo, po, name in zip(fx_outputs, pt_outputs, names):
+                    self.check_outputs(fo, po, model_class, names=name)
+            elif type(names) == str:
+                for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
+                    self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
+            else:
+                raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
+        elif isinstance(fx_outputs, jnp.ndarray):
+            self.assertTrue(isinstance(pt_outputs, torch.Tensor))
+
+            # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
+            fx_outputs = np.array(fx_outputs)
+            pt_outputs = pt_outputs.detach().to("cpu").numpy()
+
+            fx_nans = np.isnan(fx_outputs)
+            pt_nans = np.isnan(pt_outputs)
+
+            pt_outputs[fx_nans] = 0
+            fx_outputs[fx_nans] = 0
+            pt_outputs[pt_nans] = 0
+            fx_outputs[pt_nans] = 0
+
+            self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
+        else:
+            raise ValueError(
+                f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
+            )
+
     @is_pt_flax_cross_test
     def test_equivalence_pt_to_flax(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-                # load PyTorch class
-                pt_model = model_class(config).eval()
-                # Flax models don't use the `use_cache` option and cache is not returned as a default.
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
+                    # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
 
                 # make sure only flax inputs are forward that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
@@ -1759,29 +1807,41 @@ class ModelTesterMixin:
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
+
+                # convert inputs to Flax
+                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
+
                 fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                 fx_model.params = fx_state
 
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
                 with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                    pt_outputs = pt_model(**pt_inputs)
+
+                fx_outputs = fx_model(**fx_inputs)
 
-                # convert inputs to Flax
-                fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
 
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     pt_model.save_pretrained(tmpdirname)
                     fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True)
 
-                fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple()
-                self.assertEqual(
-                    len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
-                )
-                for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
-                    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+                fx_outputs_loaded = fx_model_loaded(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
     @is_pt_flax_cross_test
     def test_equivalence_flax_to_pt(self):
@@ -1789,59 +1849,78 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             with self.subTest(model_class.__name__):
-                # load corresponding PyTorch class
-                pt_model = model_class(config).eval()
-                # So we disable `use_cache` here for PyTorch model.
-                pt_model.config.use_cache = False
-
                 fx_model_class_name = "Flax" + model_class.__name__
 
                 if not hasattr(transformers, fx_model_class_name):
                     # no flax model exists for this class
                     return
 
+                # Output all for aggressive testing
+                config.output_hidden_states = True
+                config.output_attentions = self.has_attentions
+
                 fx_model_class = getattr(transformers, fx_model_class_name)
 
+                # load PyTorch class
+                pt_model = model_class(config).eval()
+                # Flax models don't use the `use_cache` option and cache is not returned as a default.
+                # So we disable `use_cache` here for PyTorch model.
+                pt_model.config.use_cache = False
+
                 # load Flax class
                 fx_model = fx_model_class(config, dtype=jnp.float32)
 
                 # make sure only flax inputs are forward that actually exist in function args
                 fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
 
-                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
-
-                # make sure weights are tied in PyTorch
-                pt_model.tie_weights()
-
                 # prepare inputs
                 pt_inputs = self._prepare_for_class(inputs_dict, model_class)
 
                 # remove function args that don't exist in Flax
                 pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}
 
-                with torch.no_grad():
-                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                # send pytorch inputs to the correct device
+                pt_inputs = {
+                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
+                }
 
+                # convert inputs to Flax
                 fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
-                fx_outputs = fx_model(**fx_inputs).to_tuple()
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
 
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
+
+                # make sure weights are tied in PyTorch
+                pt_model.tie_weights()
+
+                # send pytorch model to the correct device
+                pt_model.to(torch_device)
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs)
+
+                fx_outputs = fx_model(**fx_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
 
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     fx_model.save_pretrained(tmpdirname)
                     pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
 
+                # send pytorch model to the correct device
+                pt_model_loaded.to(torch_device)
+                pt_model_loaded.eval()
+
                 with torch.no_grad():
-                    pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
-                self.assertEqual(
-                    len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
-                )
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)
+
+                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
+                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
+
+                self.assertEqual(fx_keys, pt_keys)
+                self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
 
     def test_inputs_embeds(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
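To see the updated flow outside the test harness, here is a hedged end-to-end sketch of the PT-to-Flax direction using a small BERT (the model choice and sizes are illustrative, not from this commit); it uses the real `convert_pytorch_state_dict_to_flax` helper that the test file also relies on:

```python
import numpy as np
import torch
import jax.numpy as jnp

from transformers import BertConfig, BertModel, FlaxBertModel
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# tiny, illustrative config -- output everything for an aggressive comparison
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
config.output_hidden_states = True
config.output_attentions = True

pt_model = BertModel(config).eval()
pt_model.config.use_cache = False  # Flax models don't return a cache by default

fx_model = FlaxBertModel(config, dtype=jnp.float32)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

input_ids = torch.arange(10).reshape(1, 10)
with torch.no_grad():
    pt_outputs = pt_model(input_ids=input_ids)
fx_outputs = fx_model(input_ids=np.array(input_ids))

# compare only the output keys both frameworks actually return
fx_keys = tuple(k for k, v in fx_outputs.items() if v is not None)
pt_keys = tuple(k for k, v in pt_outputs.items() if v is not None)
assert fx_keys == pt_keys, (fx_keys, pt_keys)
```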