Unverified commit 692e61e9 authored by Crystina, committed by GitHub

Flax t5 Encoder (#17784)



* first draft adding Flax-t5-encoder and Flax-mt5-encoder

* imports

* after make fixup

* flax t5 encoder test

* black on test

* make fix-copies

* clean

* all_model_classes -> tuple

* clean test

* is_encoder_decoder=False in t5-enc tester

* remove file docstring before FlaxT5Encoder

* black

* isort

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* commit suggestions on src/transformers/models/t5/modeling_flax_t5.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* remove _get_encoder_module

* self.decoder_seq_length -> self.encoder_seq_length as t5-enc does not have decoder

* bugfix - self.module_class is class itself, not instance;

* docs for mt5 and t5

* call -> __call__ in t5 doc

* FlaxMT5EncoderModel to TYPE_HINT

* run doc-builder to allow changing the files
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent eb1493b1
@@ -96,3 +96,7 @@ See [`T5TokenizerFast`] for all details.
## FlaxMT5ForConditionalGeneration

[[autodoc]] FlaxMT5ForConditionalGeneration

## FlaxMT5EncoderModel

[[autodoc]] FlaxMT5EncoderModel
@@ -371,3 +371,8 @@ T5 is supported by several example scripts, both for pre-training and fine-tuning.
    - __call__
    - encode
    - decode

## FlaxT5EncoderModel

[[autodoc]] FlaxT5EncoderModel
    - __call__
@@ -2704,7 +2704,7 @@ else:
            "FlaxMBartPreTrainedModel",
        ]
    )
-    _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+    _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
    _import_structure["models.opt"].extend(
        [
            "FlaxOPTForCausalLM",
@@ -2743,7 +2743,9 @@ else:
        ]
    )
    _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
-    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
+    _import_structure["models.t5"].extend(
+        ["FlaxT5EncoderModel", "FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]
+    )
    _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
    _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
    _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
@@ -4974,7 +4976,7 @@ if TYPE_CHECKING:
            FlaxMBartModel,
            FlaxMBartPreTrainedModel,
        )
-        from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+        from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
        from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
        from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
        from .models.roberta import (
@@ -4997,7 +4999,7 @@ if TYPE_CHECKING:
            FlaxRoFormerPreTrainedModel,
        )
        from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
-        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+        from .models.t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
        from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
        from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
        from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
......
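The `_import_structure` entries above feed transformers' lazy-import machinery: nothing under `models.t5` is actually imported until an attribute such as `FlaxT5EncoderModel` is first accessed. A rough sketch of that pattern (simplified and illustrative; the real class is `_LazyModule` in `transformers.utils` and handles errors and `TYPE_CHECKING` more carefully):

```python
# Simplified, illustrative sketch of the lazy-module pattern behind _import_structure.
import importlib
import types


class ToyLazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # e.g. {"models.t5": ["FlaxT5EncoderModel", "FlaxT5Model", ...]}
        self._class_to_module = {
            cls_name: mod for mod, classes in import_structure.items() for cls_name in classes
        }

    def __getattr__(self, name):
        if name not in self._class_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # import the submodule only on first access, then cache the attribute
        module = importlib.import_module("." + self._class_to_module[name], self.__name__)
        value = getattr(module, name)
        setattr(self, name, value)
        return value
```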
@@ -67,7 +67,7 @@ try:
except OptionalDependencyNotAvailable:
    pass
else:
-    _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
+    _import_structure["modeling_flax_mt5"] = ["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]


if TYPE_CHECKING:
@@ -95,7 +95,7 @@ if TYPE_CHECKING:
    except OptionalDependencyNotAvailable:
        pass
    else:
-        from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+        from .modeling_flax_mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model

else:
    import sys
......
@@ -17,7 +17,7 @@
import numpy as np

from ...utils import logging
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
+from ..t5.modeling_flax_t5 import FlaxT5EncoderModel, FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config
@@ -67,6 +67,33 @@ class FlaxMT5Model(FlaxT5Model):
    config_class = MT5Config

class FlaxMT5EncoderModel(FlaxT5EncoderModel):
    r"""
    This class overrides [`FlaxT5EncoderModel`]. Please check the superclass for the appropriate documentation
    alongside usage examples.

    Examples:

    ```python
    >>> from transformers import FlaxMT5EncoderModel, T5Tokenizer

    >>> model = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")
    >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> inputs = tokenizer(article, return_tensors="np")

    >>> outputs = model(input_ids=inputs["input_ids"])
    >>> hidden_states = outputs.last_hidden_state
    ```"""

    model_type = "mt5"
    config_class = MT5Config

class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
    r"""
    This class overrides [`FlaxT5ForConditionalGeneration`]. Please check the superclass for the appropriate
......
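Because `FlaxMT5EncoderModel` only swaps in `model_type` and `config_class`, its output should agree with the encoder half of the full seq2seq model. A sanity-check sketch (assumes a Flax-enabled install and the `google/mt5-small` checkpoint; loading the encoder-only class from a full checkpoint is expected to warn about skipped decoder weights):

```python
import jax.numpy as jnp

from transformers import AutoTokenizer, FlaxMT5EncoderModel, FlaxMT5Model

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
inputs = tokenizer("UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.", return_tensors="np")

encoder_only = FlaxMT5EncoderModel.from_pretrained("google/mt5-small")
seq2seq = FlaxMT5Model.from_pretrained("google/mt5-small")

enc = encoder_only(**inputs).last_hidden_state
ref = seq2seq.encode(**inputs).last_hidden_state

# both paths run the same encoder stack with the same weights
print(float(jnp.abs(enc - ref).max()))  # should be ~0.0
```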
@@ -83,6 +83,7 @@ except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_t5"] = [
+        "FlaxT5EncoderModel",
        "FlaxT5ForConditionalGeneration",
        "FlaxT5Model",
        "FlaxT5PreTrainedModel",
@@ -143,7 +144,12 @@ if TYPE_CHECKING:
    except OptionalDependencyNotAvailable:
        pass
    else:
-        from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+        from .modeling_flax_t5 import (
+            FlaxT5EncoderModel,
+            FlaxT5ForConditionalGeneration,
+            FlaxT5Model,
+            FlaxT5PreTrainedModel,
+        )

else:
......
@@ -929,18 +929,18 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
-        decoder_input_ids = jnp.ones_like(input_ids)
-        decoder_attention_mask = jnp.ones_like(input_ids)
+        args = [input_ids, attention_mask]
+        if self.module_class not in [FlaxT5EncoderModule]:
+            decoder_input_ids = jnp.ones_like(input_ids)
+            decoder_attention_mask = jnp.ones_like(input_ids)
+            args.extend([decoder_input_ids, decoder_attention_mask])

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
-            input_ids,
-            attention_mask,
-            decoder_input_ids,
-            decoder_attention_mask,
+            *args,
        )["params"]

        if params is not None:
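For context on the `*args` change: Flax's `Module.init` traces `__call__` with exactly the positional arguments it receives, so an encoder-only module must be initialized without decoder tensors. A minimal standalone illustration (toy module, not transformers code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyEncoder(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, input_ids, attention_mask):
        x = nn.Embed(num_embeddings=100, features=self.features)(input_ids)
        return x * attention_mask[..., None]  # zero out padded positions


rng = jax.random.PRNGKey(0)
input_ids = jnp.zeros((1, 4), dtype="i4")
attention_mask = jnp.ones_like(input_ids)

# mirror of the pattern above: only two args for an encoder-only module
args = [input_ids, attention_mask]
params = ToyEncoder().init(rng, *args)["params"]
print(jax.tree_util.tree_map(jnp.shape, params))
```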
@@ -1357,6 +1357,90 @@ overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class FlaxT5EncoderModule(nn.Module):
    config: T5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
        )

        encoder_config = copy.deepcopy(self.config)
        encoder_config.is_decoder = False
        encoder_config.is_encoder_decoder = False
        encoder_config.causal = False
        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        deterministic: bool = True,
    ):
        # Encode if needed (training, first prediction pass)
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        return encoder_outputs


class FlaxT5EncoderModel(FlaxT5PreTrainedModel):
    module_class = FlaxT5EncoderModule

    @add_start_docstrings_to_model_forward(T5_ENCODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class FlaxT5ForConditionalGenerationModule(nn.Module):
    config: T5Config
......
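A short usage sketch for the new encoder-only head (assumes the standard `t5-small` checkpoint; its decoder weights are simply not used by the encoder-only class):

```python
from transformers import AutoTokenizer, FlaxT5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxT5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="np")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# encoder hidden states of shape (batch, sequence_length, d_model)
print(outputs.last_hidden_state.shape)
```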
@@ -802,6 +802,13 @@ class FlaxMBartPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["flax"])


+class FlaxMT5EncoderModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
class FlaxMT5ForConditionalGeneration(metaclass=DummyObject):
    _backends = ["flax"]
@@ -970,6 +977,13 @@ class FlaxSpeechEncoderDecoderModel(metaclass=DummyObject):
        requires_backends(self, ["flax"])


+class FlaxT5EncoderModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
class FlaxT5ForConditionalGeneration(metaclass=DummyObject):
    _backends = ["flax"]
......
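These stubs exist so `FlaxT5EncoderModel` can still be imported from `transformers` without Flax installed, failing only on use. A self-contained sketch of the mechanism (simplified; the real `DummyObject` and `requires_backends` live in `transformers.utils`):

```python
# Simplified stand-ins for transformers.utils.DummyObject / requires_backends.
def requires_backends(obj, backends):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the {backends} backend(s), which are not installed.")


class DummyObject(type):
    # any non-underscore class-attribute access raises the backend error
    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        requires_backends(cls, cls._backends)


class FlaxT5EncoderModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


# FlaxT5EncoderModel()                     -> ImportError
# FlaxT5EncoderModel.from_pretrained(...)  -> ImportError
```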
@@ -48,7 +48,12 @@ if is_flax_available():
    from flax.traverse_util import flatten_dict

    from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
-    from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
+    from transformers.models.t5.modeling_flax_t5 import (
+        FlaxT5EncoderModel,
+        FlaxT5ForConditionalGeneration,
+        FlaxT5Model,
+        shift_tokens_right,
+    )


class FlaxT5ModelTester:
@@ -461,6 +466,298 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
        self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

class FlaxT5EncoderOnlyModelTester:
    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        encoder_seq_length=7,
        # For common tests
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        # For common tests
        self.seq_length = self.encoder_seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.scope = None
        self.decoder_layers = 0

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
        attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)

        config = T5Config(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            is_encoder_decoder=False,
        )

        return (
            config,
            input_ids,
            attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        attention_mask,
    ):
        model = FlaxT5EncoderModel(config=config)
        result = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        result = model(input_ids=input_ids)
        encoder_output = result.last_hidden_state

        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            attention_mask,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_flax
class FlaxT5EncoderOnlyModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxT5EncoderModel,) if is_flax_available() else ()
    is_encoder_decoder = False

    def setUp(self):
        self.model_tester = FlaxT5EncoderOnlyModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated gelu feed forward and different word embeddings work
        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_encode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    # overwrite since special base model prefix is used
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname)

                base_param_from_head = flatten_dict(unfreeze(head_model.params))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


@require_sentencepiece
@require_tokenizers
@require_flax
......
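The jitted-encode check in `test_encode` can be reproduced standalone; a sketch with an assumed tiny random config (shape equality, not value equality, is what the test asserts):

```python
import jax
import jax.numpy as jnp

from transformers import FlaxT5EncoderModel, T5Config

config = T5Config(
    vocab_size=99, d_model=32, d_kv=8, d_ff=37, num_layers=2, num_heads=4, is_encoder_decoder=False
)
model = FlaxT5EncoderModel(config)  # randomly initialized


@jax.jit
def encode_jitted(input_ids, attention_mask=None):
    return model(input_ids=input_ids, attention_mask=attention_mask)


input_ids = jnp.ones((2, 7), dtype="i4")
jitted = encode_jitted(input_ids).to_tuple()
with jax.disable_jit():
    eager = encode_jitted(input_ids).to_tuple()

assert all(j.shape == e.shape for j, e in zip(jitted, eager))
```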