Unverified Commit 96b2b2de authored by Ogundepo Odunayo's avatar Ogundepo Odunayo Committed by GitHub
Browse files

Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)



* add converter for t5x_retrieval model

* update args

* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* style  editing -> convert t5x to pytorch

* make style
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 91ff7efe
...@@ -35,7 +35,7 @@ import torch ...@@ -35,7 +35,7 @@ import torch
from flax import traverse_util from flax import traverse_util
from t5x import checkpoints from t5x import checkpoints
from transformers import T5Config, T5ForConditionalGeneration from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging from transformers.utils import logging
...@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name): ...@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
return params[f"{prefix}/layers_{i}/{layer_name}/scale"] return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch.""" """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"]) old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()} old = {"/".join(k): v for k, v in old.items()}
...@@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): ...@@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
].T ].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
# Decoder. if not is_encoder_only:
for i in range(num_layers): # Decoder.
# Block i, layer 0 (Self Attention). for i in range(num_layers):
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") # Block i, layer 0 (Self Attention).
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
# Block i, layer 1 (Cross Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") # Block i, layer 1 (Cross Attention).
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention") layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") # Block i, layer 2 (MLP).
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
if split_mlp_wi: new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
else: new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T else:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
"decoder/relpos_bias/rel_embedding" new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
].T "decoder/relpos_bias/rel_embedding"
].T
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old: # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new return new
def make_state_dict(converted_params): def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model.""" """Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors. # Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
...@@ -162,35 +163,41 @@ def make_state_dict(converted_params): ...@@ -162,35 +163,41 @@ def make_state_dict(converted_params):
if "encoder.embed_tokens.weight" not in state_dict: if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "decoder.embed_tokens.weight" not in state_dict: if not is_encoder_only:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "lm_head.weight" not in state_dict: # For old 1.0 models. if "lm_head.weight" not in state_dict: # For old 1.0 models.
print("Using shared word embeddings as lm_head.") print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"] state_dict["lm_head.weight"] = state_dict["shared.weight"]
return state_dict return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model witht the T5X converted params.""" """Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers) converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
state_dict = make_state_dict(converted) state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path): def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model # Initialise PyTorch model
config = T5Config.from_json_file(config_file) config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}") print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all. # Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings. # The v1.0 checkpoints will simply have an LM head that is the word embeddings.
model = T5ForConditionalGeneration(config) if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint # Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# Save pytorch-model # Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}") print(f"Save PyTorch model to {pytorch_dump_path}")
...@@ -217,5 +224,10 @@ if __name__ == "__main__": ...@@ -217,5 +224,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
) )
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
args = parser.parse_args() args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment