Commit 4d87e4d8 authored by Tri Dao

Implement GPT-J

parent 4360cfc6
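A minimal sketch of how the new GPT-J path is exercised, mirroring the test added below (assumes a CUDA GPU, the flash_attn extensions installed, and access to the EleutherAI/gpt-j-6B checkpoint):

import torch
from transformers import GPTJConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config

# Translate the HF GPT-J config into the GPT2Config-style config used by this repo.
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained('EleutherAI/gpt-j-6B'))
config.fused_bias_fc = True       # optional fused kernels
config.fused_mlp = True
config.residual_in_fp32 = True
# use_flash_attn stays False: FlashAttention doesn't support head_dim 256 yet.

# from_pretrained remaps the HF checkpoint through remap_state_dict_hf_gptj.
model = GPTLMHeadModel.from_pretrained('EleutherAI/gpt-j-6B', config,
                                       device='cuda', dtype=torch.float16)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 128), dtype=torch.long, device='cuda')
with torch.no_grad():
    logits = model(input_ids).logits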
......@@ -18,12 +18,13 @@ from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_opt
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
......@@ -36,9 +37,10 @@ except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
except ImportError:
FusedDenseSqreluDense = None
sqrelu_fwd = None
logger = logging.getLogger(__name__)
......@@ -54,8 +56,11 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
dwconv = getattr(config, 'attn_dwconv', False)
if dwconv:
assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
out_proj_bias = getattr(config, 'out_proj_bias', True)
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if not fused_bias_fc:
......@@ -66,9 +71,12 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
rotary_emb_interleaved=rotary_emb_interleaved,
use_flash_attn=use_flash_attn,
**serial_kwargs, **parallel_kwargs, **factory_kwargs)
return mixer_cls
......@@ -88,8 +96,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if process_group is not None:
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
if not fused_mlp and not fused_dense_sqrelu_dense:
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
elif config.activation_function == 'sqrelu':
assert sqrelu_fwd is not None, 'sqrelu_fwd is not implemented'
activation = sqrelu_fwd
else:
approximate = ('tanh' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
......@@ -132,12 +144,27 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
prenorm = getattr(config, 'prenorm', True)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
parallel_block = getattr(config, 'parallel_block', False)
if not parallel_block:
block = Block(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None)
mark_shared_params=process_group is not None
)
else:
assert prenorm
block = ParallelBlock(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
tied_norm=getattr(config, 'parallel_block_tied_norm', False),
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None
)
block.layer_idx = layer_idx
return block
......@@ -172,9 +199,12 @@ class GPTPreTrainedModel(nn.Module):
model_name, device='cpu', dtype=dtype
)
if model_name.startswith('gpt2'):
state_dict = remap_state_dict_gpt2(state_dict, config)
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
elif model_name.startswith('facebook/opt'):
state_dict = remap_state_dict_opt(state_dict, config)
state_dict = remap_state_dict_hf_opt(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-j-'):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
strict = False # We have rotary_emb.inv_freq buffers that are not in the GPT-J checkpoint
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
......@@ -223,6 +253,8 @@ class GPTModel(GPTPreTrainedModel):
# These 2 options are for OPT-350m
self.prenorm = getattr(config, 'prenorm', True)
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
# For GPT-J, GPT-NeoX
self.parallel_block = getattr(config, 'parallel_block', False)
if process_group is None:
self.embeddings = GPT2Embeddings(
......@@ -276,6 +308,8 @@ class GPTModel(GPTPreTrainedModel):
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
if self.parallel_block:
hidden_states2 = None
residual = None
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
......@@ -283,15 +317,27 @@ class GPTModel(GPTPreTrainedModel):
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
if self.prenorm:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
if not self.parallel_block:
hidden_states, residual = layer(hidden_states, residual,
mixer_kwargs=mixer_kwargs)
else:
hidden_states, hidden_states2, residual = layer(
hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
)
else:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
if not self.parallel_block:
residual = (dropped + residual) if residual is not None else dropped
else:
dropped2 = self.drop_f(hidden_states2)
residual = ((residual + dropped + dropped2)
if residual is not None else dropped + dropped2)
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
assert not self.parallel_block
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
......@@ -308,6 +354,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
super().__init__(config)
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
......@@ -319,12 +366,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
else:
self.project_out = None
if process_group is None:
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=not self.tie_word_embeddings,
**factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
embed_dim, vocab_size, process_group, bias=False,
embed_dim, vocab_size, process_group, bias=not self.tie_word_embeddings,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
......@@ -333,6 +381,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
self.tie_weights()
def tie_weights(self):
if self.tie_word_embeddings:
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_shared_params(self, self.process_group)
......@@ -381,7 +430,95 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
return super().load_state_dict(state_dict, strict=strict)
def remap_state_dict_gpt2(state_dict, config):
def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
def shard_first_dim(state_dict, key):
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim:(rank + 1) * dim]
def shard_last_dim(state_dict, key):
x = state_dict[key]
dim = x.shape[-1] // world_size
state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
def shard_qkv_headdim(state_dict, key):
x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
dim = x.shape[1] // world_size
state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
'three d ... -> (three d) ...')
shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
shard_first_dim(state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
for i in range(config.num_hidden_layers):
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
return state_dict
def combine_state_dicts_tp(state_dicts, config):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
world_size = len(state_dicts)
keys = state_dicts[0].keys()
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
# The word embeddings from Megatron are weird, for each shard only the first
# vocab_size // world_size coordinates are nonzero.
def combine_word_embeddings(state_dicts, state_dict, key):
assert all(s[key].shape[0] == vocab_size for s in state_dicts)
state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0)
def combine_dim(state_dicts, state_dict, key, dim=-1):
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
def combine_qkv_headdim(state_dicts, state_dict, key):
xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace
combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
for i in range(config.num_hidden_layers):
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0)
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
return state_dict
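As an illustration of the sharding helper above: splitting a full (non-parallel) GPT checkpoint into per-rank shards might look roughly like this, assuming full_sd and config describe a standard GPT-2-style model (note that combine_state_dicts_tp expects Megatron-style shards whose word embeddings keep all vocab_size rows, so it is not a direct inverse of shard_state_dict_tp):

world_size = 2
shards = [shard_state_dict_tp(dict(full_sd), config, world_size, rank)
          for rank in range(world_size)]
# Each rank keeps its own slice of the attention heads and the MLP:
wqkv = shards[0]['transformer.layers.0.mixer.Wqkv.weight']
assert wqkv.shape[0] == 3 * config.hidden_size // world_size
# The biases of the row-parallel projections only live on rank 0:
assert 'transformer.layers.0.mixer.out_proj.bias' not in shards[1]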
def remap_state_dict_hf_gpt2(state_dict, config):
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
......@@ -430,47 +567,67 @@ def remap_state_dict_gpt2(state_dict, config):
return state_dict
def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
def remap_state_dict_megatron(state_dict, config):
def key_mapping_transformer(key):
key = re.sub(r'^language_model.encoder.', 'transformer.', key)
key = re.sub(r'^language_model.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
def shard_first_dim(state_dict, key):
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim:(rank + 1) * dim]
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
r'transformer.layers.\1.norm1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
r'transformer.layers.\1.norm2.\2', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
def shard_last_dim(state_dict, key):
x = state_dict[key]
dim = x.shape[-1] // world_size
state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
r'transformer.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
r'transformer.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def shard_qkv_headdim(state_dict, key):
x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
dim = x.shape[1] // world_size
state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
'three d ... -> (three d) ...')
# Attention
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
r'transformer.layers.\1.mixer.Wqkv.\2', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
r'transformer.layers.\1.mixer.out_proj.\2', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
# Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
for d in range(config.num_hidden_layers):
Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
shard_first_dim(state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
for i in range(config.num_hidden_layers):
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
return state_dict
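The Wqkv re-packing done in remap_state_dict_megatron above can be checked on a toy tensor; with nheads=2 and headdim=3 the rearrange regroups rows from per-head (q, k, v) blocks into (all q heads, all k heads, all v heads):

import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 6
w_megatron = torch.arange(nheads * 3 * headdim * hidden).reshape(nheads * 3 * headdim, hidden)
w_ours = rearrange(w_megatron, '(nheads three headdim) d -> (three nheads headdim) d',
                   three=3, headdim=headdim)
assert torch.equal(w_ours[0:3], w_megatron[0:3])    # head 0's query rows stay first
assert torch.equal(w_ours[3:6], w_megatron[9:12])   # head 1's query rows come next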
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
# We don't store these causal-mask buffers ('bias' and 'masked_bias')
state_dict.pop(f'transformer.layers.{l}.attn.bias')
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
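After this remapping, each layer's separate q_proj/k_proj/v_proj weights are stacked into a single Wqkv weight and the causal-mask buffers are gone; a quick check (config produced by gptj_config_to_gpt2_config defined just below, checkpoint downloaded from the HF hub):

from flash_attn.utils.pretrained import state_dict_from_pretrained

sd = remap_state_dict_hf_gptj(state_dict_from_pretrained('EleutherAI/gpt-j-6B'), config)
assert sd['transformer.layers.0.mixer.Wqkv.weight'].shape == (3 * config.n_embd, config.n_embd)
assert not any('attn.q_proj' in k or 'masked_bias' in k for k in sd)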
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
headdim = gptj_config.n_embd // gptj_config.n_head
return GPT2Config(
vocab_size=gptj_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gptj_config.n_embd,
n_layer=gptj_config.n_layer,
n_head=gptj_config.n_head,
n_inner=gptj_config.n_inner,
activation_function=gptj_config.activation_function,
resid_pdrop=gptj_config.resid_pdrop,
embd_pdrop=gptj_config.embd_pdrop,
attn_pdrop=gptj_config.attn_pdrop,
layer_norm_epsilon=gptj_config.layer_norm_epsilon,
initializer_range=gptj_config.initializer_range,
bos_token_id=gptj_config.bos_token_id,
eos_token_id=gptj_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=True,
parallel_block_tied_norm=True,
rotary_emb_fraction=gptj_config.rotary_dim / headdim,
rotary_emb_interleaved=True,
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
)
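For the published gpt-j-6B hyperparameters (n_embd=4096, n_head=16, rotary_dim=64, hence headdim=256), this conversion yields rotary_emb_fraction = 64 / 256 = 0.25, i.e. rotary embeddings on the first quarter of each head dimension, with parallel attention/MLP blocks and untied input/output embeddings:

from transformers import GPTJConfig

config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained('EleutherAI/gpt-j-6B'))
assert config.parallel_block and config.parallel_block_tied_norm
assert config.rotary_emb_interleaved and not config.tie_word_embeddings
assert abs(config.rotary_emb_fraction - 0.25) < 1e-6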
......@@ -11,7 +11,7 @@ import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_opt(state_dict, config):
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
......
......@@ -190,3 +190,93 @@ class Block(nn.Module):
rowscale=rowscale2, prenorm=False
)
return hidden_states
class ParallelBlock(nn.Module):
"""The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
and PaLM.
"""
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
sequence_parallel=False, mark_shared_params=False):
"""
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA / MLP -> Dropout -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
the hidden states (the outputs of the MHA and MLP branches) and the residual.
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.tied_norm = tied_norm
self.fused_dropout_add_ln = fused_dropout_add_ln
assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
self.residual_in_fp32 = residual_in_fp32
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
self.dropout2 = dropout_cls(resid_dropout2)
if not self.tied_norm:
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
# then the input to each worker in the tensor parallel group will be different.
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
# For now this is not an issue because we always use sequence_parallel=True during training
# and only use sequence_parallel=False during inference.
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._sequence_parallel = True
# Mark the norm parameters as "shared_params" so that we sync their values at init.
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, 'norm2'):
for p in self.norm2.parameters():
p._shared_params = True
def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
residual: Optional[Tensor] = None, mixer_kwargs=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states1: the output of the previous attention (mixer) or embedding layer.
hidden_states2: the output of the previous MLP layer (if None, only hidden_states1 is added to the residual, as for the very first block).
residual: the running residual stream from the previous block (None for the first block).
"""
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = ((residual + dropped1 + dropped2)
if residual is not None else dropped1 + dropped2)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm else hidden_states1)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
if mixer_kwargs is None:
mixer_kwargs = {}
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
hidden_states2 = self.mlp(hidden_states2)
return hidden_states1, hidden_states2, residual
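Ignoring dropout and the causal mask, the recurrence this block implements can be sketched in plain PyTorch as follows (tied_norm case; nn.MultiheadAttention is only a stand-in for the MHA mixer, and the sizes are made up):

import torch
import torch.nn as nn

dim, nheads = 256, 8
ln = nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, nheads, batch_first=True)   # stand-in for the mixer
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

def parallel_block_ref(h1, h2=None, residual=None):
    # Add: fold both branch outputs of the previous block into the residual stream.
    branch_sum = h1 if h2 is None else h1 + h2
    residual = branch_sum if residual is None else residual + branch_sum
    # LN (tied): the attention and MLP branches read the same normalized input.
    normed = ln(residual)
    attn_out, _ = attn(normed, normed, normed, need_weights=False)
    return attn_out, mlp(normed), residual

x = torch.randn(2, 16, dim)
h1, h2, res = parallel_block_ref(x)             # first block: no residual yet
h1, h2, res = parallel_block_ref(h1, h2, res)   # subsequent blocks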
......@@ -347,9 +347,10 @@ class MHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0,
rotary_emb_scale_base=0,
def __init__(self, embed_dim, num_heads, cross_attn=False,
qkv_proj_bias=True, out_proj_bias=True,
dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
fused_bias_fc=False, use_flash_attn=False, return_residual=False,
checkpointing=False, device=None, dtype=None) -> None:
"""
......@@ -377,7 +378,7 @@ class MHA(nn.Module):
assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
device=device)
interleaved=rotary_emb_interleaved, device=device)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
......@@ -388,18 +389,22 @@ class MHA(nn.Module):
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
if not self.cross_attn:
if not self.return_residual:
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
**factory_kwargs)
else:
self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=qkv_proj_bias,
**factory_kwargs)
if self.dwconv:
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
groups=3 * embed_dim)
else:
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
if not self.return_residual:
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
**factory_kwargs)
else:
self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=qkv_proj_bias,
**factory_kwargs)
if self.dwconv:
self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
groups=embed_dim)
......@@ -409,8 +414,7 @@ class MHA(nn.Module):
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always have the bias (for now)
self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
......@@ -526,9 +530,10 @@ class ParallelMHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
use_flash_attn=False, checkpointing=False,
sequence_parallel=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
......@@ -546,11 +551,12 @@ class ParallelMHA(nn.Module):
if self.rotary_emb_dim > 0:
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
device=device)
interleaved=rotary_emb_interleaved, device=device)
if ColumnParallelLinear is None or RowParallelLinear is None:
raise ImportError('fused_dense is not installed')
self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
......@@ -558,8 +564,8 @@ class ParallelMHA(nn.Module):
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always have the bias (for now)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
bias=out_proj_bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
......
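With the arguments added above, a GPT-J-style attention layer (no QKV or output-projection biases, partial interleaved rotary embedding) would be constructed roughly like this (sizes follow gpt-j-6B; requires the rotary_emb extension since rotary_emb_dim > 0):

import torch
from flash_attn.modules.mha import MHA

mha = MHA(embed_dim=4096, num_heads=16, causal=True,
          qkv_proj_bias=False, out_proj_bias=False,
          rotary_emb_dim=64, rotary_emb_interleaved=True,
          use_flash_attn=False,   # head_dim 256 isn't supported by FlashAttention yet
          device='cuda', dtype=torch.float16)
x = torch.randn(2, 128, 4096, device='cuda', dtype=torch.float16)
out = mha(x)   # (2, 128, 4096)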
......@@ -71,8 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
cg=False, timing=False):
eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
fused_ft_kernel=False, cg=False, timing=False):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
......@@ -87,6 +87,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
scores: tuples of (batch, vocab_size)
"""
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'):
......@@ -111,7 +112,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
next_token = teacher_outputs[:, seqlen_og]
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
while True:
......@@ -126,7 +130,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
sequences.append(next_token)
inference_params.sequence_len_offset += 1
if eos_token_id is not None and (next_token == eos_token_id).all():
......
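The new teacher_outputs argument makes decode follow a fixed token sequence instead of sampling, which is useful for checking that incremental decoding reproduces a reference generation; a rough sketch (model, input_ids and a reference teacher_out of shape (batch, max_length) are assumed; teacher_outputs is indexed by absolute position, so it should contain the prompt as a prefix):

from flash_attn.utils.generation import decode

out = decode(input_ids, model, max_length=64,
             teacher_outputs=teacher_out,   # tokens to force once the prompt is consumed
             top_k=1, fused_ft_kernel=False)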
......@@ -7,7 +7,7 @@ from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
......@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
config = GPT2Config.from_pretrained(model_name)
pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config)
pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
......
......@@ -12,8 +12,8 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
from transformers.models.opt.modeling_opt import OPTForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
......
......@@ -12,7 +12,7 @@ from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
......
import re
import torch
import pytest
from transformers import GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
state_dict = model.state_dict()
rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
for l in range(config.n_layer)}
assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
for k in state_dict.keys() - rotary_inv_freq_keys:
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = False # FlashAttention doesn't support hdim 256 yet
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = False # fused dropout_add_ln doesn't support the parallel block yet
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
model_ref = GPTJForCausalLM.from_pretrained(model_name).to(device=device)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
......@@ -7,7 +7,7 @@ from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
......@@ -15,7 +15,7 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
......