Unverified Commit 170fcaa6 authored by Duong A. Nguyen, committed by GitHub

Generalize decay_mask_fn to apply mask to all LayerNorm params (#18273)

* generalize decay_mask_fn to find all layernorm params

* fixup

* generalising decay_mask_fn
parent 83d2d745
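
For reference, a minimal, self-contained sketch of what the generalized mask computes. The `decay_mask_fn` body below is the one introduced by this commit; `toy_params` and its keys are invented purely for illustration (the real parameter trees come from the Flax models loaded in the example scripts).

```python
from flax import traverse_util


def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # collect every (module, leaf) suffix whose joined path mentions a LayerNorm-like name
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)


# toy tree, made up for this sketch only
toy_params = {
    "encoder": {
        "self_attn_layer_norm": {"scale": 1.0, "bias": 0.0},  # BART-style naming
        "LayerNorm": {"scale": 1.0, "bias": 0.0},              # BERT-style naming
        "ln_1": {"scale": 1.0, "bias": 0.0},                   # GPT-2-style naming
        "dense": {"kernel": 1.0, "bias": 0.0},
    }
}

print(decay_mask_fn(toy_params))
# Only encoder/dense/kernel is True (decayed); every bias and every LayerNorm
# scale, regardless of naming convention, is False (not decayed).
```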
@@ -875,15 +875,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
-        ]
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
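
The mask function above is consumed by the optimizer created a few lines further down (see the `# create adam optimizer` and `tx = optax.adamw(` context lines in these hunks). A sketch of that wiring, with placeholder hyperparameter values standing in for the scripts' actual training arguments:

```python
import optax

# `decay_mask_fn` is the function from the diff / sketch above; the numeric
# values here are illustrative placeholders, not the scripts' defaults.
adamw = optax.adamw(
    learning_rate=1e-3,   # the example scripts pass a linear-decay schedule here
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.01,
    mask=decay_mask_fn,   # weight decay is applied only where the mask is True
)
```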
@@ -638,15 +638,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxGPT2.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
-            for path in flat_params
-        }
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -658,12 +658,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -326,7 +326,6 @@ class FlaxDataCollatorForT5MLM:
     decoder_start_token_id: int

-
     def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
         # convert list to dict and tensorize input
         batch = BatchEncoding(
             {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -395,7 +394,6 @@ class FlaxDataCollatorForT5MLM:
         return input_ids

-
     def random_spans_noise_mask(self, length):
         """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

         Noise mask consisting of random spans of noise tokens.
@@ -782,10 +780,17 @@ def main():
     # The mask is True for parameters that should be decayed.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {
-            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
-            for path in flat_params
-        }
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -327,12 +327,19 @@ def create_train_state(
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(
@@ -723,15 +723,19 @@ def main():
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        layer_norm_params = [
-            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
-        ]
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
@@ -226,7 +226,17 @@ def create_train_state(
     # The mask is True for parameters that should be decayed.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(
@@ -284,12 +284,19 @@ def create_train_state(
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
-    # Note that this mask is specifically adapted for FlaxBERT-like models.
-    # For other models, one should correct the layer norm parameter naming
-    # accordingly.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
-        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+        # find out all LayerNorm parameters
+        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
+        layer_norm_named_params = set(
+            [
+                layer[-2:]
+                for layer_norm_name in layer_norm_candidates
+                for layer in flat_params.keys()
+                if layer_norm_name in "".join(layer).lower()
+            ]
+        )
+        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
         return traverse_util.unflatten_dict(flat_mask)

     tx = optax.adamw(