"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e8da77d181d316dd5c890e56e92cf153c7d68ee7"
Unverified Commit 0b98ca36 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Adapt Flax models to new structure (#9484)



* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Fix code quality

* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016

* Remove redundant ElectraPooler

* save intermediate

* adapt

* correct bert flax design

* adapt roberta as well

* finish roberta flax

* finish

* apply suggestions

* apply suggestions
Co-authored-by: default avatarChris Nguyen <anhtu2687@gmail.com>
parent 5c0bf397
...@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None): ...@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
return attn_mask return attn_mask
@require_flax
class FlaxModelTesterMixin: class FlaxModelTesterMixin:
model_tester = None model_tester = None
all_model_classes = () all_model_classes = ()
...@@ -90,7 +91,7 @@ class FlaxModelTesterMixin: ...@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**inputs_dict) fx_outputs = fx_model(**inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs): for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname) pt_model.save_pretrained(tmpdirname)
...@@ -103,7 +104,6 @@ class FlaxModelTesterMixin: ...@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -121,7 +121,6 @@ class FlaxModelTesterMixin: ...@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs): for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3) self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -144,7 +143,6 @@ class FlaxModelTesterMixin: ...@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs): for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape) self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self): def test_naming_convention(self):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model_class_name = model_class.__name__ model_class_name = model_class.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment