"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e42869b091c4bba9f5b2007196d4adceb54d3b28"
Unverified commit 96b2b2de, authored by Ogundepo Odunayo and committed via GitHub
Browse files

Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)



* add converter for t5x_retrieval model

* update args

* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style  editing -> convert t5x to pytorch

* make style
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 91ff7efe
......@@ -35,7 +35,7 @@ import torch
from flax import traverse_util
from t5x import checkpoints
from transformers import T5Config, T5ForConditionalGeneration
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging
......@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
......@@ -110,6 +110,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
if not is_encoder_only:
# Decoder.
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
......@@ -153,7 +154,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
return new
def make_state_dict(converted_params):
def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
......@@ -162,6 +163,7 @@ def make_state_dict(converted_params):
if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if not is_encoder_only:
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
......@@ -172,25 +174,30 @@ def make_state_dict(converted_params):
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
    """Replaces the params in `model` with the T5X converted params.

    Args:
        model: target PyTorch model (`T5ForConditionalGeneration` or `T5EncoderModel`)
            whose state dict will be overwritten in place.
        config: `T5Config` for the model; only `num_layers` is read here.
        t5x_checkpoint_path: path to the T5X (Flax) checkpoint directory.
        is_encoder_only: when True, decoder weights are neither converted nor expected,
            matching an encoder-only checkpoint.
    """
    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
    converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
    state_dict = make_state_dict(converted, is_encoder_only)
    # strict=True so any missing/unexpected key in the conversion fails loudly.
    model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
......@@ -217,5 +224,10 @@ if __name__ == "__main__":
# CLI wiring: remaining arguments plus dispatch to the converter.
parser.add_argument(
    "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
    "--is_encoder_only",
    action="store_true",
    default=False,
    # NOTE: the flag marks an encoder-ONLY checkpoint (converted to T5EncoderModel),
    # not an encoder-decoder one.
    help="Whether the checkpoint is an encoder-only (T5EncoderModel) checkpoint",
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(
    args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment