"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "29e4597950e9de8cc32321e4286f1a61894b46d4"
Unverified commit eb881674, authored by Suraj Patil and committed by GitHub
Browse files

[Flax] [WIP] allow loading head model with base model weights (#12255)

* boom boom

* remove flax clip example

* allow loading head model with base model weights

* add test

* fix imports

* disable save, load test for clip

* add test_save_load_to_base
parent 8d5b7f36
...@@ -348,6 +348,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -348,6 +348,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state: if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix] state = state[cls.base_model_prefix]
# if model is head model and we are loading weights from base model
# we initialize new params dict with base_model_prefix
if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
state = {cls.base_model_prefix: state}
# flatten dicts # flatten dicts
state = flatten_dict(state) state = flatten_dict(state)
......
...@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
[self.model_tester.num_attention_heads, seq_length, seq_length], [self.model_tester.num_attention_heads, seq_length, seq_length],
) )
# FlaxCLIPVisionModel does not have any base model
def test_save_load_from_base(self):
    """Intentionally skipped: FlaxCLIPVisionModel has no separate base model,
    so the mixin's base->head weight-loading round-trip does not apply."""
    pass
def test_save_load_to_base(self):
    """Intentionally skipped: FlaxCLIPVisionModel has no separate base model,
    so the mixin's head->base weight-loading round-trip does not apply."""
    pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: for model_class_name in self.all_model_classes:
...@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def setUp(self): def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self) self.model_tester = FlaxCLIPTextModelTester(self)
# FlaxCLIPTextModel does not have any base model
def test_save_load_from_base(self):
    """Intentionally skipped: FlaxCLIPTextModel has no separate base model,
    so the mixin's base->head weight-loading round-trip does not apply."""
    pass
def test_save_load_to_base(self):
    """Intentionally skipped: FlaxCLIPTextModel has no separate base model,
    so the mixin's head->base weight-loading round-trip does not apply."""
    pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes: for model_class_name in self.all_model_classes:
......
...@@ -32,7 +32,9 @@ if is_flax_available(): ...@@ -32,7 +32,9 @@ if is_flax_available():
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla import jaxlib.xla_extension as jax_xla
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
from transformers.modeling_flax_pytorch_utils import ( from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax, convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model, load_flax_weights_in_pytorch_model,
...@@ -273,6 +275,50 @@ class FlaxModelTesterMixin: ...@@ -273,6 +275,50 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs): for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3) self.assert_almost_equals(output_loaded, output, 1e-3)
def test_save_load_from_base(self):
    """Check that a head model can load weights saved by its base model.

    Saves a base model to disk, loads it into each head model class, and
    verifies that every base-model parameter under the head model's
    ``base_model_prefix`` is numerically identical to the saved weights.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname)

            base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

            for key in base_param_from_head.keys():
                # Use the max absolute element-wise difference: a plain
                # .sum() lets positive and negative differences cancel (and
                # any negative sum would trivially pass the bound below).
                max_diff = jnp.abs(base_params[key] - base_param_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_save_load_to_base(self):
    """Check that a base model can load weights saved by a head model.

    Saves each head model to disk, loads the checkpoint with the base model
    class, and verifies that the base model's parameters are numerically
    identical to the head model's sub-tree under ``base_model_prefix``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                # Max absolute difference, not a signed sum: summed
                # differences can cancel out and a negative sum would
                # always satisfy assertLessEqual against 1e-3.
                max_diff = jnp.abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow @slow
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment