Unverified commit f7ca656f authored by Shubhamai, committed by GitHub

[Flax] adding support for batch norm layers (#21581)

* [flax] adding support for batch norm layers

* fixing bugs related to pt+flax integration

* cleanup, batchnorm support in sharded pt to flax

* support for batchnorm tests in pt+flax integration

* simplifying checking batch norm layer
parent 279008ad
@@ -83,6 +83,16 @@ def rename_key_and_reshape_tensor(
     if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
         return renamed_pt_tuple_key, pt_tensor
 
+    # batch norm layer mean
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
+    if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
+    # batch norm layer var
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
+    if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
     # embedding
     renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
     if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
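To make the renaming concrete, here is a minimal sketch (illustration only, not part of the commit) of what the two new branches do to a flattened PyTorch key; the key tuple below is hypothetical:

# Hypothetical flattened key for a PyTorch BatchNorm running statistic
pt_tuple_key = ("encoder", "bn1", "running_mean")

# Flax's nn.BatchNorm stores running statistics under "mean" and "var",
# so the PyTorch "running_mean"/"running_var" suffixes are rewritten
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
print(renamed_pt_tuple_key)  # ('encoder', 'bn1', 'mean')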
@@ -118,13 +128,25 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
 
     model_prefix = flax_model.base_model_prefix
-    random_flax_state_dict = flatten_dict(flax_model.params)
+
+    # use params dict if the model contains batch norm layers
+    if "params" in flax_model.params:
+        flax_model_params = flax_model.params["params"]
+    else:
+        flax_model_params = flax_model.params
+    random_flax_state_dict = flatten_dict(flax_model_params)
+
+    # add batch_stats keys,values to dict
+    if "batch_stats" in flax_model.params:
+        flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
+        random_flax_state_dict.update(flax_batch_stats)
+
     flax_state_dict = {}
 
-    load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+    load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
         model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
     )
-    load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+    load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
         model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
     )
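The branching on "params" exists because Flax modules that contain nn.BatchNorm keep two top-level variable collections instead of one. A small standalone sketch, using plain flax.linen rather than a transformers model (TinyNet is a made-up module):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Dense(4)(x)
        # BatchNorm adds a "batch_stats" collection alongside "params"
        x = nn.BatchNorm(use_running_average=not train)(x)
        return x

variables = TinyNet().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
print(sorted(variables.keys()))  # ['batch_stats', 'params']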
@@ -154,8 +176,22 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
                     f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                 )
 
-        # also add unexpected weight so that warning is thrown
-        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+        # add batch stats if the model contains batchnorm layers
+        if "batch_stats" in flax_model.params:
+            if "mean" in flax_key[-1] or "var" in flax_key[-1]:
+                flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                continue
+            # remove num_batches_tracked key
+            if "num_batches_tracked" in flax_key[-1]:
+                flax_state_dict.pop(flax_key, None)
+                continue
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
+
+        else:
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
 
     return unflatten_dict(flax_state_dict)
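Since every key written into flax_state_dict is now prefixed with ("params",) or ("batch_stats",), the final unflatten_dict call yields the two-collection layout Flax expects. A tiny sketch with made-up keys:

from flax.traverse_util import unflatten_dict

flat = {
    ("params", "bn1", "scale"): 1.0,
    ("batch_stats", "bn1", "mean"): 0.0,
    ("batch_stats", "bn1", "var"): 1.0,
}
print(unflatten_dict(flat))
# {'params': {'bn1': {'scale': 1.0}}, 'batch_stats': {'bn1': {'mean': 0.0, 'var': 1.0}}}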
@@ -176,12 +212,21 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
 
         model_prefix = flax_model.base_model_prefix
-        random_flax_state_dict = flatten_dict(flax_model.params)
 
-        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
+        # use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict
+        if "batch_stats" in flax_model.params:
+            flax_model_params = flax_model.params["params"]
+
+            random_flax_state_dict = flatten_dict(flax_model_params)
+            random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
+        else:
+            flax_model_params = flax_model.params
+            random_flax_state_dict = flatten_dict(flax_model_params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
             model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
         )
-        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
+        load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
             model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
         )
 
         # Need to change some parameters name to match Flax names
@@ -209,8 +254,25 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
                         f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                     )
 
-            # also add unexpected weight so that warning is thrown
-            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
+            # add batch stats if the model contains batchnorm layers
+            if "batch_stats" in flax_model.params:
+                if "mean" in flax_key[-1]:
+                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                    continue
+                if "var" in flax_key[-1]:
+                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                    continue
+                # remove num_batches_tracked key
+                if "num_batches_tracked" in flax_key[-1]:
+                    flax_state_dict.pop(flax_key, None)
+                    continue
+
+                # also add unexpected weight so that warning is thrown
+                flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor)
+
+            else:
+                # also add unexpected weight so that warning is thrown
+                flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
 
     return unflatten_dict(flax_state_dict)
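For reference, the PyTorch side of the mapping: a BatchNorm module carries three buffers, one of which (num_batches_tracked) has no Flax counterpart and is therefore dropped by both converters above. A quick check, assuming torch is installed:

import torch

bn = torch.nn.BatchNorm1d(4)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']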
@@ -299,7 +361,16 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
         elif flax_key_tuple[-1] in ["scale", "embedding"]:
             flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
 
-        flax_key = ".".join(flax_key_tuple)
+        # adding batch stats from flax batch norm to pt
+        elif "mean" in flax_key_tuple[-1]:
+            flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
+        elif "var" in flax_key_tuple[-1]:
+            flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)
+
+        if "batch_stats" in flax_state:
+            flax_key = ".".join(flax_key_tuple[1:])  # Remove the params/batch_stats header
+        else:
+            flax_key = ".".join(flax_key_tuple)
 
         if flax_key in pt_model_dict:
             if flax_tensor.shape != pt_model_dict[flax_key].shape:
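A minimal sketch (hypothetical key, not from the diff) of the reverse direction this hunk adds: a Flax batch_stats key becomes a PyTorch running-stats key, and the leading collection name is stripped before joining:

flax_key_tuple = ("batch_stats", "encoder", "bn1", "mean")  # hypothetical
flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
print(".".join(flax_key_tuple[1:]))  # encoder.bn1.running_mean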
@@ -837,14 +837,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 # keep the params on CPU if we don't want to initialize
                 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
 
-        # if model is base model only use model_prefix key
-        if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
-            state = state[cls.base_model_prefix]
+        if "batch_stats" in state:  # if flax model contains batch norm layers
+            # if model is base model only use model_prefix key
+            if (
+                cls.base_model_prefix not in dict(model.params_shape_tree["params"])
+                and cls.base_model_prefix in state["params"]
+            ):
+                state["params"] = state["params"][cls.base_model_prefix]
+                state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
 
-        # if model is head model and we are loading weights from base model
-        # we initialize new params dict with base_model_prefix
-        if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
-            state = {cls.base_model_prefix: state}
+            # if model is head model and we are loading weights from base model
+            # we initialize new params dict with base_model_prefix
+            if (
+                cls.base_model_prefix in dict(model.params_shape_tree["params"])
+                and cls.base_model_prefix not in state["params"]
+            ):
+                state = {
+                    "params": {cls.base_model_prefix: state["params"]},
+                    "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
+                }
+
+        else:
+            # if model is base model only use model_prefix key
+            if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
+                state = state[cls.base_model_prefix]
+
+            # if model is head model and we are loading weights from base model
+            # we initialize new params dict with base_model_prefix
+            if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
+                state = {cls.base_model_prefix: state}
 
         # flatten dicts
         state = flatten_dict(state)
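A toy illustration of the new head-model branch, with a hypothetical "bert" prefix: when base-model weights are loaded into a model with a head, both collections get wrapped under the prefix so that the key paths line up:

base_model_prefix = "bert"  # hypothetical
state = {"params": {"dense": {}}, "batch_stats": {"bn": {}}}
state = {
    "params": {base_model_prefix: state["params"]},
    "batch_stats": {base_model_prefix: state["batch_stats"]},
}
print(list(state["params"]))  # ['bert']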
@@ -854,6 +875,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         missing_keys = model.required_params - set(state.keys())
         unexpected_keys = set(state.keys()) - model.required_params
 
+        # Disable the warning when porting PyTorch weights to Flax: Flax does not use num_batches_tracked
+        for unexpected_key in unexpected_keys.copy():
+            if "num_batches_tracked" in unexpected_key[-1]:
+                unexpected_keys.remove(unexpected_key)
+
         if missing_keys and not _do_init:
             logger.warning(
                 f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
@@ -118,6 +118,30 @@ def random_attention_mask(shape, rng=None):
     return attn_mask
 
 
+def get_params(params, from_head_prefix=None):
+    """Extracts relevant parameters into a flattened dict from the model params and
+    appends batch normalization statistics if present."""
+
+    # If both parameters and batch normalization statistics are present
+    if "batch_stats" in params:
+        # Extract only parameters for the specified head prefix (if specified) and add batch statistics
+        if from_head_prefix is not None:
+            extracted_params = flatten_dict(unfreeze(params["params"][from_head_prefix]))
+            extracted_params.update(flatten_dict(params["batch_stats"][from_head_prefix]))
+        else:
+            extracted_params = flatten_dict(unfreeze(params["params"]))
+            extracted_params.update(flatten_dict(params["batch_stats"]))
+
+    # Only parameters are present
+    else:
+        if from_head_prefix is not None:
+            extracted_params = flatten_dict(unfreeze(params[from_head_prefix]))
+        else:
+            extracted_params = flatten_dict(unfreeze(params))
+
+    return extracted_params
+
+
 @require_flax
 class FlaxModelTesterMixin:
     model_tester = None
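A usage sketch for the new helper, with a made-up nested dict standing in for model.params (flatten_dict and unfreeze are assumed in scope, imported from flax.traverse_util and flax.core.frozen_dict as in the test file):

params = {
    "params": {"model": {"dense": {"kernel": 1.0}}},
    "batch_stats": {"model": {"bn": {"mean": 0.0}}},
}
# Both collections end up merged into one flat dict keyed by tuples
print(get_params(params, from_head_prefix="model"))
# {('dense', 'kernel'): 1.0, ('bn', 'mean'): 0.0}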
@@ -426,14 +450,14 @@ class FlaxModelTesterMixin:
                 continue
 
             model = base_class(config)
-            base_params = flatten_dict(unfreeze(model.params))
+            base_params = get_params(model.params)
 
             # check that all base model weights are loaded correctly
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 head_model = model_class.from_pretrained(tmpdirname)
 
-                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
+                base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
 
                 for key in base_param_from_head.keys():
                     max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
@@ -448,14 +472,14 @@ class FlaxModelTesterMixin:
                 continue
 
             model = model_class(config)
-            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
+            base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
 
             # check that all base model weights are loaded correctly
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 base_model = base_class.from_pretrained(tmpdirname)
 
-                base_params = flatten_dict(unfreeze(base_model.params))
+                base_params = get_params(base_model.params)
 
                 for key in base_params_from_head.keys():
                     max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
@@ -471,7 +495,7 @@ class FlaxModelTesterMixin:
                 continue
 
             model = base_class(config)
-            base_params = flatten_dict(unfreeze(model.params))
+            base_params = get_params(model.params)
 
             # convert Flax model to PyTorch model
             pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
@@ -484,7 +508,7 @@ class FlaxModelTesterMixin:
                 pt_model.save_pretrained(tmpdirname)
                 head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
 
-                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
+                base_param_from_head = get_params(head_model.params, from_head_prefix=head_model.base_model_prefix)
 
                 for key in base_param_from_head.keys():
                     max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
@@ -500,7 +524,7 @@ class FlaxModelTesterMixin:
                 continue
 
             model = model_class(config)
-            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
+            base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
 
             # convert Flax model to PyTorch model
             pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
@@ -512,7 +536,7 @@ class FlaxModelTesterMixin:
                 pt_model.save_pretrained(tmpdirname)
                 base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
 
-                base_params = flatten_dict(unfreeze(base_model.params))
+                base_params = get_params(base_model.params)
 
                 for key in base_params_from_head.keys():
                     max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
@@ -529,7 +553,7 @@ class FlaxModelTesterMixin:
             model = model_class(config)
             model.params = model.to_bf16(model.params)
-            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
+            base_params_from_head = get_params(model.params, from_head_prefix=model.base_model_prefix)
 
             # convert Flax model to PyTorch model
             pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
@@ -541,7 +565,7 @@ class FlaxModelTesterMixin:
                 pt_model.save_pretrained(tmpdirname)
                 base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
 
-                base_params = flatten_dict(unfreeze(base_model.params))
+                base_params = get_params(base_model.params)
 
                 for key in base_params_from_head.keys():
                     max_diff = (base_params[key] - base_params_from_head[key]).sum().item()