Unverified Commit d481b641 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Make Flax pt-flax equivalence test more aggressive (#15841)



* Make test_equivalence_pt_to_flax more aggressive

* Make test_equivalence_flax_to_pt more aggressive

* don't use to_tuple

* clean-up

* fix missing test cases + testing on GPU

* fix conversion

* fix `ValueError: assignment destination is read-only`

* Add type checking

* commit to revert later

* Fix

* fix

* fix device

* better naming

* clean-up
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent c03b6e42
...@@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login ...@@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax from transformers.testing_utils import (
PASS,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
torch_device,
)
from transformers.utils import logging from transformers.utils import logging
...@@ -160,15 +168,64 @@ class FlaxModelTesterMixin: ...@@ -160,15 +168,64 @@ class FlaxModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class) dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self): def test_equivalence_pt_to_flax(self):
# It might be better to put this inside the for loop below (because we modify the config there).
# But logically, it is fine.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
with self.subTest(model_class.__name__): with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
# prepare inputs # prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class # load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
...@@ -183,24 +240,30 @@ class FlaxModelTesterMixin: ...@@ -183,24 +240,30 @@ class FlaxModelTesterMixin:
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state fx_model.params = fx_state
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname) pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
) pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
@is_pt_flax_cross_test @is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self): def test_equivalence_flax_to_pt(self):
...@@ -208,9 +271,14 @@ class FlaxModelTesterMixin: ...@@ -208,9 +271,14 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
with self.subTest(model_class.__name__): with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
# Pure convolutional models have no attention
# prepare inputs # prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class # load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
...@@ -227,27 +295,34 @@ class FlaxModelTesterMixin: ...@@ -227,27 +295,34 @@ class FlaxModelTesterMixin:
# make sure weights are tied in PyTorch # make sure weights are tied in PyTorch
pt_model.tie_weights() pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
for fx_output, pt_output in zip(fx_outputs, pt_outputs): self.assertEqual(fx_keys, pt_keys)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname) fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
with torch.no_grad(): with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() pt_outputs_loaded = pt_model_loaded(**pt_inputs)
self.assertEqual( fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): self.assertEqual(fx_keys, pt_keys)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment