"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "efeab6a3f1eeaffc2cec350ffce797f209ba38f8"
Unverified Commit ecd6efe7 authored by Guido Novati, committed by GitHub

Fix megatron_gpt2 attention block's causal mask (#12007)



* Fix megatron_gpt2 attention block's causal mask.

* Compatibility with checkpoints created with recent versions of Megatron-LM

* Added integration test for the released Megatron-GPT2 model

* Code style changes

* Added option to the Megatron conversion script to read from a config file
Co-authored-by: Guido Novati <gnovati@nvidia.com>
parent 783b0dd5
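The core of the causal-mask fix, shown here as a minimal standalone sketch rather than as part of the diff (the sequence length is a toy value; the released 345M checkpoint uses 1024): the old converter registered an all-ones attention bias per layer, which masks nothing, while the fixed converter stores the lower-triangular causal mask that GPT-2's attention block uses as its `bias` buffer.

import torch

n_embed = 4  # toy size for illustration only

# Before the fix: every entry is 1, so future positions are never masked.
old_bias = torch.ones(1, 1, n_embed, n_embed)

# After the fix: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.uint8)).view(1, 1, n_embed, n_embed)

print(old_bias[0, 0])
print(causal_mask[0, 0])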
@@ -24,6 +24,8 @@ import zipfile

import torch

from transformers import GPT2Config


####################################################################################################

@@ -48,17 +50,45 @@ def recursive_print(name, val, spaces=0):
        print(msg, ":", val)

def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
    # for compatibility with later versions of NVIDIA Megatron-LM.
    # The inverse operation is performed inside Megatron-LM to read checkpoints:
    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
    # If param is the weight tensor of the self-attention block, the returned tensor
    # will have to be transposed one more time to be read by HuggingFace GPT2.
    input_shape = param.size()
    if checkpoint_version == 1.0:
        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
        param = param.view(*saved_shape)
        param = param.transpose(0, 2)
        param = param.transpose(1, 2).contiguous()
    elif checkpoint_version >= 2.0:
        # other versions store [num_heads * num_splits * hidden_size, :]
        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
        param = param.view(*saved_shape)
        param = param.transpose(0, 1).contiguous()
    param = param.view(*input_shape)
    return param
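A quick illustration of the reordering performed by this helper (not part of the diff; the toy sizes are invented and the helper above is assumed to be in scope): a checkpoint_version 2.0 tensor laid out head-major along its first dimension is regrouped split-major, so all query rows come first, then all key rows, then all value rows.

import torch

num_heads, num_splits, head_size = 2, 3, 4
# 24 rows, one per (head, split, element) triple; row j simply holds the value j.
param = torch.arange(num_heads * num_splits * head_size, dtype=torch.float32).view(-1, 1)

out = fix_query_key_value_ordering(param, checkpoint_version=2.0, num_splits=num_splits, num_heads=num_heads, hidden_size=head_size)

# The overall shape is unchanged; only the row grouping moves.
assert out.shape == param.shape
# Output rows 0-3 are head 0's query rows (input rows 0-3),
# rows 4-7 are head 1's query rows (input rows 12-15), and so on.
assert out[4, 0].item() == 12.0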
####################################################################################################


def convert_megatron_checkpoint(args, input_state_dict, config):
    # The converted output model.
    output_state_dict = {}

    # The number of heads.
    heads = config.n_head
    # The hidden_size per head.
    hidden_size_per_head = config.n_embd // config.n_head
    # Megatron-LM checkpoint version
    if "checkpoint_version" in input_state_dict.keys():
        checkpoint_version = input_state_dict["checkpoint_version"]
    else:
        checkpoint_version = 0.0

    # The model.
    model = input_state_dict["model"]
@@ -69,22 +99,21 @@ def convert_megatron_checkpoint(args, input_state_dict):

    # The word embeddings.
    word_embeddings = embeddings["word_embeddings"]["weight"]

    # Truncate the embedding table to vocab_size rows.
    word_embeddings = word_embeddings[: config.vocab_size, :]
    output_state_dict["transformer.wte.weight"] = word_embeddings

    # The position embeddings.
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    # Read the hidden dimension.
    n_embed = pos_embeddings.size(0)
    # DEBUG.
    assert n_embed == heads * hidden_size_per_head

    # Store the position embeddings.
    output_state_dict["transformer.wpe.weight"] = pos_embeddings

    # The transformer.
    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

    # The regex to extract layer names.
    layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
@@ -92,6 +121,7 @@ def convert_megatron_checkpoint(args, input_state_dict):

    # The simple map of names for "automated" rules.
    megatron_to_transformers = {
        "attention.dense": ".attn.c_proj.",
        "self_attention.dense": ".attn.c_proj.",
        "mlp.dense_h_to_4h": ".mlp.c_fc.",
        "mlp.dense_4h_to_h": ".mlp.c_proj.",
    }
@@ -122,26 +152,32 @@ def convert_megatron_checkpoint(args, input_state_dict):

            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # Transpose the QKV matrix.
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "weight":
            # Insert a tensor of 1x1xDxD bias.
            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.uint8)).view(1, 1, n_embed, n_embed)
            output_state_dict[layer_name + ".attn.bias"] = causal_mask

            # Insert a "dummy" tensor for masked_bias.
            masked_bias = torch.tensor(-1e4)
            output_state_dict[layer_name + ".attn.masked_bias"] = masked_bias

            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D.
            out_val = out_val.transpose(0, 1).contiguous()
            # Store.
            output_state_dict[layer_name + ".attn.c_attn.weight"] = out_val

        # Transpose the bias.
        elif (
            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
        ) and weight_or_bias == "bias":
            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
            # Store. No change of shape.
            output_state_dict[layer_name + ".attn.c_attn.bias"] = out_val

        # Transpose the weights.
        elif weight_or_bias == "weight":
@@ -155,6 +191,9 @@ def convert_megatron_checkpoint(args, input_state_dict):

            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + "bias"] = val

    # DEBUG.
    assert config.n_layer == layer_idx + 1

    # The final layernorm.
    output_state_dict["transformer.ln_f.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["transformer.ln_f.bias"] = transformer["final_layernorm.bias"]
@@ -162,33 +201,8 @@ def convert_megatron_checkpoint(args, input_state_dict):

    # For LM head, transformers' wants the matrix to weight embeddings.
    output_state_dict["lm_head.weight"] = word_embeddings

    # It should be done!
    return output_state_dict
####################################################################################################

@@ -198,21 +212,62 @@ def main():

    # Create the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--print-checkpoint-structure", action="store_true")
    parser.add_argument(
        "path_to_checkpoint",
        type=str,
        help="Path to the ZIP file containing the checkpoint",
    )
    parser.add_argument(
        "--config_file",
        default="",
        type=str,
        help="An optional config json file describing the pre-trained model.",
    )
    args = parser.parse_args()

    # Extract the basename.
    basename = os.path.dirname(args.path_to_checkpoint)

    # Load the model.
    print(f"Extracting PyTorch state dictionary from {args.path_to_checkpoint}")
    with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
    # Read the config, or default to the model released by NVIDIA.
    if args.config_file == "":
        # Spell out all parameters in case the defaults change.
        config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_ctx=1024,
            n_embd=1024,
            n_layer=24,
            n_head=16,
            n_inner=4096,
            activation_function="gelu_new",
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            summary_type="cls_index",
            summary_use_proj=True,
            summary_activation=None,
            summary_proj_to_labels=True,
            summary_first_dropout=0.1,
            scale_attn_weights=True,
            gradient_checkpointing=False,
            use_cache=True,
            bos_token_id=50256,
            eos_token_id=50256,
        )
    else:
        config = GPT2Config.from_json_file(args.config_file)
    # Convert.
    print("Converting")
    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

    # Print the structure of converted state dict.
    if args.print_checkpoint_structure:

@@ -220,6 +275,9 @@ def main():

    # Store the config to file.
    output_config_file = os.path.join(basename, "config.json")
    output_config = config.to_dict()
    output_config["architectures"] = ["GPT2LMHeadModel"]
    output_config["model_type"] = "gpt2"
    print(f'Saving config to "{output_config_file}"')
    with open(output_config_file, "w") as f:
        json.dump(output_config, f)

...
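As an illustrative follow-up (not part of the commit), and assuming the converted state dict has been written alongside the generated config.json in the same output folder, that folder can then be loaded like any other local GPT-2 checkpoint. The directory name below is hypothetical.

from transformers import GPT2Config, GPT2LMHeadModel

# Hypothetical output folder produced by the conversion script.
checkpoint_dir = "megatron_gpt2_345m"

# The generated config.json can be read back explicitly...
config = GPT2Config.from_json_file(f"{checkpoint_dir}/config.json")

# ...and passed to from_pretrained along with the converted weights.
model = GPT2LMHeadModel.from_pretrained(checkpoint_dir, config=config)
model.eval()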
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
if is_torch_available():
    import torch

    from transformers import GPT2LMHeadModel
@require_torch
@require_sentencepiece
@require_tokenizers
class MegatronGPT2IntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        directory = "nvidia/megatron-gpt2-345m/"
        if "MYDIR" in os.environ:
            directory = os.path.join(os.environ["MYDIR"], directory)
        model = GPT2LMHeadModel.from_pretrained(directory)
        model.to(torch_device)
        model.half()

        input_ids = torch.tensor(
            [[101, 7110, 1005, 1056, 2023, 11333, 17413, 1029, 102]],
            device=torch_device,
            dtype=torch.long,
        )

        with torch.no_grad():
            output = model(input_ids).logits

        expected_shape = torch.Size((1, 9, 50257))
        self.assertEqual(output.shape, expected_shape)

        expected_diag = torch.tensor(
            [
                4.9414,
                -0.2920,
                -1.2148,
                -4.0273,
                -0.5161,
                -5.2109,
                -1.2412,
                -1.8301,
                -1.7734,
                -4.7148,
                -0.2317,
                -1.0811,
                -2.1777,
                0.4141,
                -3.7969,
                -4.0586,
                -2.5332,
                -3.3809,
                4.3867,
            ],
            device=torch_device,
            dtype=torch.half,
        )

        for i in range(19):
            r, c = 8 * i // 17, 2792 * i  # along the diagonal
            computed, expected = output[0, r, c], expected_diag[i]
            msg = f"row={r} col={c} computed={computed} expected={expected}"
            self.assertAlmostEqual(computed, expected, delta=1e-4, msg=msg)