Unverified Commit 96b2b2de authored by Ogundepo Odunayo's avatar Ogundepo Odunayo Committed by GitHub
Browse files

Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)



* add converter for t5x_retrieval model

* update args

* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* style  editing -> convert t5x to pytorch

* make style
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 91ff7efe
......@@ -35,7 +35,7 @@ import torch
from flax import traverse_util
from t5x import checkpoints
from transformers import T5Config, T5ForConditionalGeneration
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging
......@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
......@@ -110,6 +110,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
if not is_encoder_only:
# Decoder.
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
......@@ -153,7 +154,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
return new
def make_state_dict(converted_params):
def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
......@@ -162,6 +163,7 @@ def make_state_dict(converted_params):
if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if not is_encoder_only:
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
......@@ -172,25 +174,30 @@ def make_state_dict(converted_params):
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers)
state_dict = make_state_dict(converted)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
......@@ -217,5 +224,10 @@ if __name__ == "__main__":
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment