"vscode:/vscode.git/clone" did not exist on "37378898a204c91e4aa47f163dd635a06c437628"
Unverified Commit 96b2b2de authored by Ogundepo Odunayo's avatar Ogundepo Odunayo Committed by GitHub
Browse files

Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)



* add converter for t5x_retrieval model

* update args

* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* style  editing -> convert t5x to pytorch

* make style
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 91ff7efe
...@@ -35,7 +35,7 @@ import torch ...@@ -35,7 +35,7 @@ import torch
from flax import traverse_util from flax import traverse_util
from t5x import checkpoints from t5x import checkpoints
from transformers import T5Config, T5ForConditionalGeneration from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging from transformers.utils import logging
...@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name): ...@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
return params[f"{prefix}/layers_{i}/{layer_name}/scale"] return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch.""" """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"]) old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()} old = {"/".join(k): v for k, v in old.items()}
...@@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): ...@@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
].T ].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"] new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
# Decoder. if not is_encoder_only:
for i in range(num_layers): # Decoder.
# Block i, layer 0 (Self Attention). for i in range(num_layers):
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") # Block i, layer 0 (Self Attention).
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
# Block i, layer 1 (Cross Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm") # Block i, layer 1 (Cross Attention).
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention") layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") # Block i, layer 2 (MLP).
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
if split_mlp_wi: new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
else: new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T else:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[ new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
"decoder/relpos_bias/rel_embedding" new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
].T "decoder/relpos_bias/rel_embedding"
].T
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old: # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new return new
def make_state_dict(converted_params): def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model.""" """Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors. # Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
...@@ -162,35 +163,41 @@ def make_state_dict(converted_params): ...@@ -162,35 +163,41 @@ def make_state_dict(converted_params):
if "encoder.embed_tokens.weight" not in state_dict: if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "decoder.embed_tokens.weight" not in state_dict: if not is_encoder_only:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"] if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "lm_head.weight" not in state_dict: # For old 1.0 models. if "lm_head.weight" not in state_dict: # For old 1.0 models.
print("Using shared word embeddings as lm_head.") print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"] state_dict["lm_head.weight"] = state_dict["shared.weight"]
return state_dict return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model witht the T5X converted params.""" """Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers) converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
state_dict = make_state_dict(converted) state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path): def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model # Initialise PyTorch model
config = T5Config.from_json_file(config_file) config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}") print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all. # Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings. # The v1.0 checkpoints will simply have an LM head that is the word embeddings.
model = T5ForConditionalGeneration(config) if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint # Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# Save pytorch-model # Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}") print(f"Save PyTorch model to {pytorch_dump_path}")
...@@ -217,5 +224,10 @@ if __name__ == "__main__": ...@@ -217,5 +224,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
) )
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
args = parser.parse_args() args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment