Unverified Commit b71c38a0 authored by JiangZhongqing, committed by GitHub

Fix bug in T5X to PyTorch conversion script with varying encoder and decoder layers (#27448)

* Fix bug in handling varying encoder and decoder layers

This commit resolves an issue where the script failed to convert T5X models to PyTorch models when the number of decoder layers differed from the number of encoder layers. The fix passes an additional num_decoder_layers parameter to the relevant function (see the illustrative sketch below).

* Fix bug in handling varying encoder and decoder layers
parent 2e72bbab
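
For illustration only, a minimal sketch of the failure mode; the parameter dict and layer counts below are made up, not taken from a real checkpoint. The decoder loop previously reused num_layers, so a checkpoint whose decoder depth differs from its encoder depth either hit a KeyError during the parameter lookup or left decoder weights unconverted, which the later strict load_state_dict call rejects.

    # Illustrative sketch: a fake flattened T5X parameter dict with a 2-layer
    # encoder and a 1-layer decoder (real checkpoints hold arrays, not strings).
    old = {
        "encoder/layers_0/pre_attention_layer_norm/scale": "...",
        "encoder/layers_1/pre_attention_layer_norm/scale": "...",
        "decoder/layers_0/pre_self_attention_layer_norm/scale": "...",
    }

    num_layers = 2          # encoder depth from the config
    num_decoder_layers = 1  # decoder depth from the config

    # Old behaviour: the decoder loop reused num_layers, so i = 1 would look up
    # a decoder layer that does not exist and raise KeyError.
    # for i in range(num_layers):
    #     old[f"decoder/layers_{i}/pre_self_attention_layer_norm/scale"]

    # Fixed behaviour: iterate over the decoder's own depth.
    for i in range(num_decoder_layers):
        old[f"decoder/layers_{i}/pre_self_attention_layer_norm/scale"]
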
@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
     return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
 
 
-def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
+def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, num_decoder_layers: int, is_encoder_only: bool):
     """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
     old = traverse_util.flatten_dict(variables["target"])
     old = {"/".join(k): v for k, v in old.items()}
@@ -112,7 +112,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only:
 
     if not is_encoder_only:
         # Decoder.
-        for i in range(num_layers):
+        for i in range(num_decoder_layers):
             # Block i, layer 0 (Self Attention).
             layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
             k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
@@ -177,7 +177,12 @@ def make_state_dict(converted_params, is_encoder_only: bool):
 
 def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
     """Replaces the params in model witht the T5X converted params."""
     variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
-    converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
+    converted = convert_t5x_to_pytorch(
+        variables,
+        num_layers=config.num_layers,
+        num_decoder_layers=config.num_decoder_layers,
+        is_encoder_only=is_encoder_only,
+    )
     state_dict = make_state_dict(converted, is_encoder_only)
     model.load_state_dict(state_dict, strict=True)
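
A hedged usage sketch of the updated entry point; the layer counts and the commented-out checkpoint path are placeholders, not taken from this change. With the extra keyword argument, a config whose decoder depth differs from its encoder depth can be passed straight through:

    from transformers import T5Config

    # Hypothetical asymmetric configuration: 6 encoder layers, 12 decoder layers.
    config = T5Config(num_layers=6, num_decoder_layers=12)

    # variables = checkpoints.load_t5x_checkpoint("/path/to/t5x_checkpoint")  # placeholder path
    # converted = convert_t5x_to_pytorch(
    #     variables,
    #     num_layers=config.num_layers,
    #     num_decoder_layers=config.num_decoder_layers,
    #     is_encoder_only=False,
    # )
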