Unverified Commit 60e448c8 authored by Patrick von Platen, committed by GitHub

[Flax] Correct pt to flax conversion if from base to head (#13006)

* finish PR

* add tests

* correct tests

* finish

* correct other flax tests

* better naming

* correct naming

* finish

* apply Sylvain's suggestions
parent 33929448
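In short, this PR fixes PyTorch-to-Flax weight conversion when source and target differ by a task head: loading a head checkpoint (e.g. `*ForMaskedLM`) into a base model requires stripping the base-model prefix from the PT keys, and loading a base checkpoint into a head model requires prepending it. A minimal sketch of the first direction, not part of the diff (BERT names are illustrative; assumes torch and flax are installed):

    # Hedged sketch of the head -> base scenario this PR fixes.
    import tempfile

    from transformers import BertConfig, BertForMaskedLM, FlaxBertModel

    config = BertConfig(num_hidden_layers=2)  # small random model
    with tempfile.TemporaryDirectory() as tmpdir:
        BertForMaskedLM(config).save_pretrained(tmpdir)  # PT keys start with "bert."
        # head -> base: the "bert." prefix must be stripped to match the Flax base params
        flax_base = FlaxBertModel.from_pretrained(tmpdir, from_pt=True)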
@@ -17,6 +17,7 @@

import os
from pickle import UnpicklingError
from typing import Dict, Tuple

import numpy as np
@@ -58,60 +59,97 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa

    return flax_state_dict
def rename_key_and_reshape_tensor(
    pt_tuple_key: Tuple[str],
    pt_tensor: np.ndarray,
    random_flax_state_dict: Dict[str, jnp.ndarray],
    model_prefix: str,
) -> Tuple[Tuple[str], np.ndarray]:
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
        """Checks if ``key`` or ``(model_prefix,) + key`` is in random_flax_state_dict"""
        return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0

    # layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
    if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        # PT conv weights are (out_ch, in_ch, kh, kw); Flax kernels are (kh, kw, in_ch, out_ch)
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        # PT linear weights are (out_features, in_features); Flax kernels are transposed
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor
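As an illustration (not part of the diff), a hedged example of the linear-layer branch above, with made-up keys and shapes:

    import numpy as np

    # "classifier.weight" matches no Flax key named "weight", "scale", or "embedding",
    # so it falls through to the linear branch: renamed to "kernel" and transposed
    # from (out_features, in_features) to (in_features, out_features).
    random_flax_state_dict = {("classifier", "kernel"): None}  # only the keys matter here
    flax_key, flax_tensor = rename_key_and_reshape_tensor(
        ("classifier", "weight"), np.zeros((2, 768)), random_flax_state_dict, model_prefix="bert"
    )
    assert flax_key == ("classifier", "kernel")
    assert flax_tensor.shape == (768, 2)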
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    model_prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
        model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
        model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameter names to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        # remove base model prefix if necessary
        has_base_model_prefix = pt_tuple_key[0] == model_prefix
        if load_model_with_head_into_base_model and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(
            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
        )

        # add model prefix if necessary
        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
        if load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key = (model_prefix,) + flax_key

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
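For context, a hedged usage sketch of the entry point above in the head -> base direction (model names are assumptions; requires torch and flax):

    from transformers import BertConfig, BertForMaskedLM, FlaxBertModel
    from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

    config = BertConfig(num_hidden_layers=2)
    pt_head = BertForMaskedLM(config)  # PT keys carry the "bert." prefix
    flax_base = FlaxBertModel(config)  # Flax params carry no prefix

    # head -> base: load_model_with_head_into_base_model is True, so the prefix is stripped
    flax_params = convert_pytorch_state_dict_to_flax(pt_head.state_dict(), flax_base)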
@@ -154,10 +192,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):

    flax_state_dict = flatten_dict(flax_state)
    pt_model_dict = pt_model.state_dict()

    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
        pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
    )
    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
        pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
    )
@@ -170,9 +208,9 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):

        require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict

        # adapt flax_key to prepare for loading from/to base model only
        if load_model_with_head_into_base_model and has_base_model_prefix:
            flax_key_tuple = flax_key_tuple[1:]
        elif load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple

        # rename flax weights to PyTorch format
...
@@ -213,9 +213,20 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):

    def test_save_load_from_base(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    def test_save_load_to_base(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
@@ -307,9 +318,20 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):

    def test_save_load_from_base(self):
        pass

    # FlaxCLIPTextModel does not have any base model
    def test_save_load_to_base(self):
        pass

    # FlaxCLIPTextModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        pass

    # FlaxCLIPTextModel does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
...
@@ -334,6 +334,63 @@ class FlaxModelTesterMixin:

            max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -17,8 +17,15 @@ import unittest

import numpy as np

import transformers
from transformers import is_flax_available
from transformers.testing_utils import (
    is_pt_flax_cross_test,
    require_flax,
    require_sentencepiece,
    require_tokenizers,
    slow,
)
from .test_configuration_common import ConfigTester
from .test_generation_flax_utils import FlaxGenerationTesterMixin

@@ -40,6 +47,7 @@ if is_flax_available():

    from flax.training.common_utils import onehot
    from flax.traverse_util import flatten_dict
    from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
    from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
@@ -363,6 +371,65 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.

            max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite since special base model prefix is used
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@require_sentencepiece
@require_tokenizers
...
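Note on the T5 overrides above: they mirror the mixin tests but read `flatten_dict(unfreeze(model.params))` and `head_model.params` directly, instead of indexing `params[base_model_prefix]` as the mixin does, because (per the "special base model prefix" comment) the Flax T5 head models do not nest their parameters under the base model prefix.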