Commit ae856f3a authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

Remove unnecessary files

parent 6ac8e63a
This diff is collapsed.
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r"^model.decoder.", "transformer.", key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r"^decoder.", "transformer.", key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
# The OPT-350m model uses has project_in and project_out
key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
key = re.sub(r"^transformer.project_out.", "project_out.", key)
key = re.sub(
r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
key = re.sub(
r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
)
key = re.sub(
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(
r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(
r"^transformer.layers.(\d+).self_attn.out_proj.",
r"transformer.layers.\1.mixer.out_proj.",
key,
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (
None
if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim
)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
n_embd=opt_config.hidden_size,
n_layer=opt_config.num_hidden_layers,
n_head=opt_config.num_attention_heads,
n_inner=opt_config.ffn_dim,
activation_function=opt_config.activation_function,
resid_pdrop=opt_config.dropout,
# HF's implementation of OPT doesn't seem to have embedding dropout
embd_pdrop=opt_config.dropout,
attn_pdrop=opt_config.attention_dropout,
initializer_range=opt_config.init_std,
bos_token_id=opt_config.bos_token_id,
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim,
)
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2022, Tri Dao.
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from flash_attn.utils.distributed import all_reduce, reduce_scatter
class GPT2Embeddings(nn.Module):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
padding_idx=None,
word_embed_proj_dim=None,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
the project up to embed_dim
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if word_embed_proj_dim is None:
self.word_embeddings = nn.Embedding(
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
)
self.project_in = None
else:
self.word_embeddings = nn.Embedding(
vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
)
self.project_in = nn.Linear(
word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(
max_position_embeddings, embed_dim, **factory_kwargs
)
def forward(self, input_ids, position_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.project_in is not None:
embeddings = self.project_in(embeddings)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
return embeddings
class BertEmbeddings(nn.Module):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
type_vocab_size,
padding_idx=None,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
If type_vocab_size <= 0, there's no token type embeddings
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.word_embeddings = nn.Embedding(
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
)
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(
max_position_embeddings, embed_dim, **factory_kwargs
)
if self.type_vocab_size > 0:
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
def forward(self, input_ids, position_ids=None, token_type_ids=None):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
token_type_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if self.type_vocab_size > 0:
if token_type_ids is None:
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings
return embeddings
class VocabParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if num_embeddings % world_size != 0:
raise ValueError(
f"num_embeddings ({num_embeddings}) must be divisible by "
f"world_size ({world_size})"
)
if world_size > 1 and padding_idx is not None:
raise RuntimeError("ParallelEmbedding does not support padding_idx")
else:
world_size = 1
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
def forward(self, input: Tensor) -> Tensor:
if self.process_group is None:
return super().forward(input)
else:
rank = torch.distributed.get_rank(self.process_group)
vocab_size = self.num_embeddings
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
input = input - vocab_start_index
input[input_ids_mask] = 0
embeddings = super().forward(input)
embeddings[input_ids_mask] = 0.0
return embeddings
class ColumnParallelEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
self.process_group = process_group
if process_group is not None:
world_size = torch.distributed.get_world_size(process_group)
if embedding_dim % world_size != 0:
raise ValueError(
f"embedding_dim ({embedding_dim}) must be divisible by "
f"world_size ({world_size})"
)
else:
world_size = 1
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
class ParallelGPT2Embeddings(nn.Module):
def __init__(
self,
embed_dim,
vocab_size,
max_position_embeddings,
process_group,
padding_idx=None,
sequence_parallel=True,
device=None,
dtype=None,
):
"""
If max_position_embeddings <= 0, there's no position embeddings
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.word_embeddings = VocabParallelEmbedding(
vocab_size,
embed_dim,
padding_idx=padding_idx,
process_group=process_group,
**factory_kwargs,
)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = ColumnParallelEmbedding(
max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
)
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
"""
input_ids: (batch, seqlen)
position_ids: (batch, seqlen)
"""
batch_size, seqlen = input_ids.shape
world_size = torch.distributed.get_world_size(self.process_group)
embeddings = self.word_embeddings(input_ids)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
position_embeddings = self.position_embeddings(position_ids)
if world_size <= 1:
embeddings = embeddings + position_embeddings
else:
partition_dim = self.position_embeddings.embedding_dim
rank = torch.distributed.get_rank(self.process_group)
embeddings[
..., rank * partition_dim : (rank + 1) * partition_dim
] += position_embeddings
if combine_batch_seqlen_dim:
embeddings = rearrange(embeddings, "b s d -> (b s) d")
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment