Commit 4d87e4d8 authored by Tri Dao

Implement GPT-J

parent 4360cfc6
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r'^transformer.h.', 'transformer.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('lm_head.weight')
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
        key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
            [Wq, Wk, Wv], dim=0
        )
        # We don't store these biases
        state_dict.pop(f'transformer.layers.{l}.attn.bias')
        state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')

    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
    )
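For orientation, here is a minimal usage sketch (not part of the commit) of how the two helpers above fit together: convert the HF GPTJConfig, remap the HF checkpoint keys, and load the result into the flash-attn GPTLMHeadModel. For "EleutherAI/gpt-j-6B" (n_embd=4096, n_head=16, rotary_dim=64) the rotary_emb_fraction works out to 64 / 256 = 0.25. The strict=False load is an assumption made here to skip the rotary inv_freq buffers that the remapped dict does not carry (see test_gptj_state_dict below); the tests instead go through GPTLMHeadModel.from_pretrained.

from transformers import GPTJConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "EleutherAI/gpt-j-6B"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
# strict=False: the model keeps its own rotary inv_freq buffers, which the
# remapped HF state dict does not contain.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert not unexpected and all('rotary_emb.inv_freq' in k for k in missing)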
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from transformers import GPT2Config, OPTConfig


-def remap_state_dict_opt(state_dict, config):
+def remap_state_dict_hf_opt(state_dict, config):
     def key_mapping_model(key):
         key = re.sub(r'^model.decoder.', 'transformer.', key)
         # The OPT-350m model uses '^decoder' instead of '^model.decoder'
...
@@ -190,3 +190,93 @@ class Block(nn.Module):
                     rowscale=rowscale2, prenorm=False
                 )
             return hidden_states
+
+
+class ParallelBlock(nn.Module):
+    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
+    and PaLM.
+    """
+
+    def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
+                 dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
+                 tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
+                 sequence_parallel=False, mark_shared_params=False):
+        """
+        This Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
+        the hidden_states (output1 of the MHA / MLP) and the residual.
+        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+        """
+        super().__init__()
+        self.tied_norm = tied_norm
+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
+        self.residual_in_fp32 = residual_in_fp32
+        if mixer_cls is None:
+            mixer_cls = partial(MHA, num_heads=dim // 64)
+        if mlp_cls is None:
+            mlp_cls = partial(Mlp, hidden_features=4 * dim)
+        self.mixer = mixer_cls(dim)
+        self.dropout1 = dropout_cls(resid_dropout1)
+        self.norm1 = norm_cls(dim)
+        self.mlp = mlp_cls(dim)
+        self.dropout2 = dropout_cls(resid_dropout2)
+        if not self.tied_norm:
+            self.norm2 = norm_cls(dim)
+
+        if self.fused_dropout_add_ln:
+            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
+        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
+        if sequence_parallel:
+            for p in self.norm1.parameters():
+                p._sequence_parallel = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._shared_params = True
+
+    def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
+                residual: Optional[Tensor] = None, mixer_kwargs=None):
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            hidden_states1: the output of the previous attention (mixer) or embedding layer.
+            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
+            residual: the residual stream from the previous block (required except for the first block).
+        """
+        dropped1 = self.dropout1(hidden_states1)
+        # For the very 1st block, we only want 1 dropout, not two different dropouts
+        if hidden_states2 is not None:
+            dropped2 = self.dropout2(hidden_states2)
+            residual = ((residual + dropped1 + dropped2)
+                        if residual is not None else dropped1 + dropped2)
+        else:
+            residual = (residual + dropped1) if residual is not None else dropped1
+        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                          if not self.tied_norm else hidden_states1)
+        if self.residual_in_fp32:
+            residual = residual.to(torch.float32)
+        if mixer_kwargs is None:
+            mixer_kwargs = {}
+        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
+        hidden_states2 = self.mlp(hidden_states2)
+        return hidden_states1, hidden_states2, residual
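To make the data flow above concrete: a parallel block updates the residual stream as x + Attn(LN1(x)) + MLP(LN2(x)), with LN2 tied to LN1 in the GPT-J configuration; ParallelBlock merely defers the dropout-and-add of the two branch outputs to the next block so that dropout + add + LayerNorm can later be fused. Below is a hedged, simplified sketch of that update using stock PyTorch modules (nn.MultiheadAttention and a plain MLP stand in for the mixer and fused MLP; dropout, rotary embeddings, and causal masking are omitted). It is a reference illustration, not the library implementation.

import torch
import torch.nn as nn

class NaiveParallelBlock(nn.Module):
    """Reference-only parallel residual block: x + Attn(LN1(x)) + MLP(LN2(x))."""
    def __init__(self, dim, num_heads, tied_norm=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = self.norm1 if tied_norm else nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        # Sequential prenorm would be: x = x + Attn(LN(x)); then x = x + MLP(LN(x)).
        # Parallel (GPT-J style): both branches read the same normalized input,
        # and their outputs are added back to the residual stream together.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        mlp_out = self.mlp(self.norm2(x))
        return x + attn_out + mlp_out

x = torch.randn(2, 16, 64)  # (batch, seqlen, dim)
print(NaiveParallelBlock(dim=64, num_heads=4)(x).shape)  # torch.Size([2, 16, 64])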
@@ -347,9 +347,10 @@ class MHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0,
+    def __init__(self, embed_dim, num_heads, cross_attn=False,
+                 qkv_proj_bias=True, out_proj_bias=True,
+                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
+                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
         """
@@ -377,7 +378,7 @@ class MHA(nn.Module):
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)

         if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
@@ -388,29 +389,32 @@ class MHA(nn.Module):
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         if not self.cross_attn:
             if not self.return_residual:
-                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
+                                       **factory_kwargs)
             else:
-                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
+                                             **factory_kwargs)
             if self.dwconv:
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
-            self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
             if not self.return_residual:
-                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
+                                      **factory_kwargs)
             else:
-                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
+                                            **factory_kwargs)
             if self.dwconv:
                 self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
-                                        groups=embed_dim)
+                                          groups=embed_dim)
                 self.dwconv_kv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, kernel_size=3, padding=2,
-                                        groups=2 * embed_dim)
+                                          groups=2 * embed_dim)
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
-        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
+        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
@@ -526,9 +530,10 @@ class ParallelMHA(nn.Module):
     """Multi-head self-attention and cross-attention
     """

-    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+    def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
+                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
+                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
+                 use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -546,11 +551,12 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
-                                              device=device)
+                                              interleaved=rotary_emb_interleaved, device=device)

         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
-        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
+        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group,
+                                         bias=qkv_proj_bias,
                                          sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -558,8 +564,8 @@ class ParallelMHA(nn.Module):
                                          attention_dropout=dropout)
         self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                                      attention_dropout=dropout)
-        # output projection always have the bias (for now)
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+                                          bias=out_proj_bias,
                                           sequence_parallel=sequence_parallel, **factory_kwargs)

     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
...
@@ -71,8 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):


 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
-           cg=False, timing=False):
+           eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
+           fused_ft_kernel=False, cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
@@ -87,6 +87,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
+    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
         assert fused_ft_kernel
         if not hasattr(model, '_decoding_cache'):
@@ -111,7 +112,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     if vocab_size is not None:
         logits = logits[..., :vocab_size]
     scores.append(logits)
-    next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    if teacher_outputs is None or teacher_output_len <= seqlen_og:
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    else:
+        next_token = teacher_outputs[:, seqlen_og]
     sequences = [next_token]
     inference_params.sequence_len_offset = seqlen_og
     while True:
@@ -126,7 +130,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
-        next_token = sample(logits, top_k=top_k, temperature=temperature)
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
+            next_token = sample(logits, top_k=top_k, temperature=temperature)
+        else:
+            next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
         sequences.append(next_token)
         inference_params.sequence_len_offset += 1
         if eos_token_id is not None and (next_token == eos_token_id).all():
...
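The new teacher_outputs argument lets the caller force the generated token at each position instead of sampling it, which can be used, for example, to drive the decoder along a known token sequence. The per-step rule in the two hunks above is: take teacher_outputs[:, t] for step t while the teacher sequence covers that step, otherwise fall back to sampling. Below is a self-contained toy illustration of just that rule; pick_next_token and the greedy stand-in for sample() are hypothetical names for illustration only, not library functions.

import torch

def pick_next_token(logits, step, teacher_outputs=None, sample_fn=None):
    # Mirrors the branch added to decode(): force the token from teacher_outputs
    # when it is provided and covers this step, otherwise sample from the logits.
    if teacher_outputs is None or teacher_outputs.shape[1] <= step:
        return sample_fn(logits)
    return teacher_outputs[:, step]

logits = torch.randn(2, 100)                    # (batch, vocab)
teacher = torch.tensor([[7, 3, 9], [1, 4, 2]])  # (batch, teacher_len)
greedy = lambda l: l.argmax(dim=-1)             # stand-in for flash_attn's sample()
print(pick_next_token(logits, step=1, teacher_outputs=teacher, sample_fn=greedy))  # forced: tensor([3, 4])
print(pick_next_token(logits, step=5, teacher_outputs=teacher, sample_fn=greedy))  # sampled (teacher too short)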
@@ -7,7 +7,7 @@ from transformers import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_gpt2_state_dict(model_name):
     config = GPT2Config.from_pretrained(model_name)
-    pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()
...
@@ -12,8 +12,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.generation import update_graph_cache
...
@@ -12,7 +12,7 @@ from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
...
import re

import torch
import pytest

from transformers import GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
                            for l in range(config.n_layer)}
    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
    for k in state_dict.keys() - rotary_inv_freq_keys:
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
    """Check that our implementation of GPT-J (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = False  # We don't support parallel block yet
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
@@ -7,7 +7,7 @@ from transformers import OPTConfig
 from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_state_dict(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
-    pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
+    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config)
     state_dict = model.state_dict()
     assert state_dict.keys() == pretrained_state_dict.keys()
...