Commit 96d10f65 authored by Tri Dao

Implement LLaMa

parent b630aef5
@@ -487,18 +487,14 @@ def remap_state_dict(state_dict, config):
             state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
                 [Wq, Wk, Wv], dim=0
             )
-            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat(
-                [bq, bk, bv], dim=0
-            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
         else:
             state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
             state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
                 [Wk, Wv], dim=0
             )
             state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
-            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat(
-                [bk, bv], dim=0
-            )
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
     def key_mapping_attn(key):
         return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                       r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
...
@@ -43,6 +43,16 @@ try:
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None

+try:
+    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
+except ImportError:
+    RMSNorm, dropout_add_rms_norm = None, None
+
+try:
+    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+except ImportError:
+    dropout_add_rms_norm_parallel_residual = None
+
 try:
     from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
@@ -90,6 +100,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):

 def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
+    mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
+    mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
         assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
@@ -108,7 +120,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                           else (F.silu if config.activation_function == 'swiglu'
                                 else F.gelu))
             mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
-                              **factory_kwargs)
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
         else:
             if config.activation_function == 'relu':
                 activation = partial(F.relu, inplace=True)
@@ -119,7 +131,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                                in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
                 activation=partial(F.gelu, approximate=approximate)
             mlp_cls = partial(Mlp, hidden_features=config.n_inner, activation=activation,
-                              **factory_kwargs)
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -137,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                                if process_group is not None else {})
             mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
                               checkpoint_lvl=mlp_checkpoint_lvl,
+                              bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
                               **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None
@@ -152,7 +165,9 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
     sequence_parallel = getattr(config, 'sequence_parallel', True)
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
-    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
+    use_rms_norm = getattr(config, 'rms_norm', False)
+    norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
+                       eps=config.layer_norm_epsilon, **factory_kwargs)
     # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
     residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
     resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
@@ -267,6 +282,7 @@ class GPTModel(GPTPreTrainedModel):
         self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
         # These 2 options are for OPT-350m
         self.prenorm = getattr(config, 'prenorm', True)
+        use_rms_norm = getattr(config, 'rms_norm', False)
         word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
         # For GPT-J, GPT-NeoX
         self.parallel_block = getattr(config, 'parallel_block', False)
@@ -300,7 +316,8 @@ class GPTModel(GPTPreTrainedModel):
                 raise ImportError('dropout_layer_norm is not installed')
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
-            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
+            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
+            self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
                                  **factory_kwargs)
             if process_group is not None:
                 for p in self.ln_f.parameters():
@@ -512,30 +529,39 @@ def combine_state_dicts_tp(state_dicts, config):
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0

-    # The word embeddings from Megatron are weird, for each shard only the first
-    # vocab_size // world_size coordinates are nonzero.
+    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
     def combine_word_embeddings(state_dicts, state_dict, key):
-        assert all(s[key].shape[0] == vocab_size for s in state_dicts)
-        state_dict[key] = torch.cat([s[key][:vocab_size // world_size] for s in state_dicts], dim=0)
+        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
+        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_dim(state_dicts, state_dict, key, dim=-1):
-        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
+        if key in state_dict:
+            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_qkv_headdim(state_dicts, state_dict, key):
-        xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
-        state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+        if key in state_dict:
+            xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
+            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+
+    def combine_gated_mlp(state_dicts, state_dict, key):
+        if key in state_dict:
+            xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
+            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')

     state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
     combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
     if 'lm_head.weight' in state_dict:
         combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
     if 'transformer.embeddings.position_embeddings.weight' in state_dict:
         combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
+    mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
+                      else partial(combine_dim, dim=0))
     for i in range(config.num_hidden_layers):
         combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
         combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
-        combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight', 0)
+        mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
         combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
     return state_dict
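
The extra rearrange in combine_gated_mlp is needed because each tensor-parallel shard stores its fc1 weight as its slice of w3 stacked on top of its slice of w1, so a plain concatenation along dim 0 would interleave the two halves. A minimal standalone sketch with toy shapes (not part of the diff):

import torch
from einops import rearrange

hidden, d_shard, world_size = 4, 3, 2
# Each rank holds fc1.weight of shape (2 * d_shard, hidden): its slice of w3 stacked on its slice of w1.
shards = [torch.randn(2 * d_shard, hidden) for _ in range(world_size)]

# A plain concat along dim 0 would give the layout [w3_0; w1_0; w3_1; w1_1].
naive = torch.cat(shards, dim=0)

# combine_gated_mlp instead splits each shard into its two halves, concatenates per half, then restacks.
xs = [rearrange(s, '(two d) ... -> two d ...', two=2) for s in shards]
combined = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')  # [w3_0; w3_1; w1_0; w1_1]

assert combined.shape == naive.shape == (2 * d_shard * world_size, hidden)
assert torch.equal(combined[:d_shard], shards[0][:d_shard])                          # all of w3 comes first
assert torch.equal(combined[d_shard * world_size:][:d_shard], shards[0][d_shard:])   # then all of w1
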
@@ -603,7 +629,8 @@ def remap_state_dict_megatron(state_dict, config):
     word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
     # It's possible that vocab_size is padded to be a multiple of 8, for example.
     pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
-    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
+                  * pad_vocab_size_multiple)
     state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
     )
...
@@ -56,9 +56,7 @@ def remap_state_dict_hf_gptj(state_dict, config):
         Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
         Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
         Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
-            [Wq, Wk, Wv], dim=0
-        )
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
         # We don't store these biases
         state_dict.pop(f'transformer.layers.{l}.attn.bias')
         state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
...
# Copyright (c) 2023, Tri Dao.

import math
import json
import re
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, LlamaConfig


def remap_state_dict_meta_llama(state_dict, config):
    def key_mapping_layers(key):
        return f'transformer.{key}' if not key.startswith('output.') else key
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
                  * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, 'tie_word_embeddings'):
        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
    else:
        output_embeddings = state_dict.pop('output.weight')
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
                      * pad_vocab_size_multiple)
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict['lm_head.weight'] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
        w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
        # Our ordering differs from Meta's: GatedMlp packs the up projection (w3) into the first
        # half of fc1 and the SwiGLU gate (w1) into the second half.
        state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
    def key_mapping_mlp(key):
        return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
                      r'transformer.layers.\1.mlp.fc2.', key)
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).attention.wo.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
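
As a quick sanity check on the [w3, w1] ordering above, the following standalone sketch (not part of the diff) reproduces Meta's SwiGLU feed-forward from the packed fc1 weight; it assumes GatedMlp splits fc1's output into an up half followed by a gate half:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden = 8, 16
x = torch.randn(3, dim)
w1, w3, w2 = torch.randn(hidden, dim), torch.randn(hidden, dim), torch.randn(dim, hidden)

# Meta's LLaMa FFN: w2(silu(w1(x)) * w3(x))
ref = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

# Remapped fc1 = cat([w3, w1]); split its output into (up, gate) halves and recombine.
up, gate = F.linear(x, torch.cat([w3, w1], dim=0)).chunk(2, dim=-1)
out = F.linear(up * F.silu(gate), w2)

assert torch.allclose(ref, out, atol=1e-5)

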
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / 'params.json') as f:
        params = json.load(f)
    config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
                         num_attention_heads=params['n_heads'],
                         num_hidden_layers=params['n_layers'],
                         rms_norm_eps=params['norm_eps'])
    return config


def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> list:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [torch.load(path, map_location='cpu')
            for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    return GPT2Config(
        vocab_size=llama_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function='swiglu',  # Hardcode since HF calls it 'silu'
        # LLaMa doesn't have dropout, presumably because only the inference code was released
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=llama_config.pad_token_id,  # Not sure this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
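
For reference, these helpers are meant to be combined roughly as follows (a usage sketch based on the tests added in this commit; the checkpoint directory is a hypothetical path holding Meta's original consolidated.*.pth shards and params.json):

import torch
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp
from flash_attn.models.llama import (config_from_checkpoint, llama_config_to_gpt2_config,
                                     remap_state_dict_meta_llama, state_dicts_from_checkpoint)

checkpoint_path, model_name = '/path/to/checkpoints/llama', '7B'  # hypothetical location
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))

# Remap each checkpoint shard to our naming, then merge the shards into one state dict.
shards = state_dicts_from_checkpoint(checkpoint_path, model_name)
state_dict = combine_state_dicts_tp([remap_state_dict_meta_llama(s, config) for s in shards], config)

model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)
model.load_state_dict(state_dict, strict=False)  # rotary inv_freq buffers are not in the checkpoint
model.eval()
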
@@ -66,12 +66,8 @@ def remap_state_dict_hf_opt(state_dict, config):
         bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
         bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
         bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
-            [Wq, Wk, Wv], dim=0
-        )
-        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat(
-            [bq, bk, bv], dim=0
-        )
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
     def key_mapping_attn(key):
         return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
                       r'transformer.layers.\1.mixer.out_proj.', key)
...
@@ -23,6 +23,16 @@ try:
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None

+try:
+    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
+except ImportError:
+    RMSNorm, dropout_add_rms_norm = None, None
+
+try:
+    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+except ImportError:
+    dropout_add_rms_norm_parallel_residual = None
+

 class Block(nn.Module):
@@ -70,7 +80,9 @@ class Block(nn.Module):
         if self.fused_dropout_add_ln:
             assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
-            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+            assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
+            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
+                    and isinstance(self.dropout1, nn.Dropout))
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
         # then the input to each worker in the tensor parallel group will be different.
@@ -104,6 +116,8 @@ class Block(nn.Module):
             before applying the query projection. Useful for e.g., ViT where we only care
             about the CLS token in the last layer.
         """
+        fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.norm1, RMSNorm)
+                             else dropout_add_layer_norm)
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -119,7 +133,7 @@ class Block(nn.Module):
                         hidden_states.shape[:-1], device=hidden_states.device,
                         dtype=hidden_states.dtype)
                     )
-                hidden_states, residual = dropout_add_layer_norm(
+                hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
@@ -146,7 +160,7 @@ class Block(nn.Module):
                         hidden_states.shape[:-1], device=hidden_states.device,
                         dtype=hidden_states.dtype)
                     )
-                hidden_states, residual = dropout_add_layer_norm(
+                hidden_states, residual = fused_add_norm_fn(
                     hidden_states, residual, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
                     rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
@@ -170,7 +184,7 @@ class Block(nn.Module):
                     rowscale1 = self.drop_path1(torch.ones(
                         mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
                     )
-                hidden_states = dropout_add_layer_norm(
+                hidden_states = fused_add_norm_fn(
                     mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=False
@@ -189,7 +203,7 @@ class Block(nn.Module):
                     rowscale2 = self.drop_path2(torch.ones(
                         mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
                     )
-                hidden_states = dropout_add_layer_norm(
+                hidden_states = fused_add_norm_fn(
                     mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
                     rowscale=rowscale2, prenorm=False
@@ -234,7 +248,9 @@ class ParallelBlock(nn.Module):
         if self.fused_dropout_add_ln:
             assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
-            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+            assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
+            assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
+                    and isinstance(self.dropout1, nn.Dropout))
         # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
         # then the input to each worker in the tensor parallel group will be different.
@@ -266,6 +282,9 @@ class ParallelBlock(nn.Module):
             hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
             residual.
         """
+        fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
+                             if isinstance(self.norm1, RMSNorm)
+                             else dropout_add_layer_norm_parallel_residual)
         if not self.fused_dropout_add_ln:
             dropped1 = self.dropout1(hidden_states1)
             # For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -283,7 +302,7 @@ class ParallelBlock(nn.Module):
         else:
             weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
                               if not self.tied_norm else (None, None))
-            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                 hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
                 weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
                 prenorm=True, residual_in_fp32=self.residual_in_fp32
...
@@ -13,15 +13,15 @@ except ImportError:
 class Mlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 return_residual=False, device=None, dtype=None):
+                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features * 4
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
         y = self.fc1(x)
@@ -33,16 +33,17 @@ class Mlp(nn.Module):
 class GatedMlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
-                 multiple_of=128, return_residual=False, device=None, dtype=None):
+                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
+                 device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or int(8 * in_features / 3)
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
         y = self.fc1(x)
...
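
The multiple_of default moving from 128 to 256 makes GatedMlp's derived hidden size (used when n_inner is None) line up with the FFN widths of the released LLaMa checkpoints. A small worked check (not part of the diff):

def gated_mlp_hidden(in_features, multiple_of=256):
    # Same rounding as GatedMlp.__init__ above: 8/3 of the model width, rounded up to multiple_of.
    hidden_features = int(8 * in_features / 3)
    return (hidden_features + multiple_of - 1) // multiple_of * multiple_of

assert gated_mlp_hidden(4096) == 11008                    # LLaMa-7B FFN width
assert gated_mlp_hidden(5120) == 13824                    # LLaMa-13B FFN width
assert gated_mlp_hidden(5120, multiple_of=128) == 13696   # the old default would not match 13B
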
@@ -351,7 +351,7 @@ class DropoutAddLayerNorm(torch.nn.Module):
         super().__init__()
         self.prenorm = prenorm
         self.p = p
-        self.epsilon = eps
+        self.eps = eps
         self.residual_in_fp32 = residual_in_fp32
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
         self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
@@ -363,5 +363,5 @@ class DropoutAddLayerNorm(torch.nn.Module):
     def forward(self, x0, residual=None):
         return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
-                                      self.p if self.training else 0.0, self.epsilon,
+                                      self.p if self.training else 0.0, self.eps,
                                       prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
@@ -51,6 +51,22 @@ def dropout_add_rms_norm_parallel_residual(
     )

+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        init.ones_(self.weight)
+
+    def forward(self, x):
+        return rms_norm(x, self.weight, self.eps)
+

 class DropoutAddRMSNorm(torch.nn.Module):
     def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                  device=None, dtype=None):
@@ -58,7 +74,7 @@ class DropoutAddRMSNorm(torch.nn.Module):
         super().__init__()
         self.prenorm = prenorm
         self.p = p
-        self.epsilon = eps
+        self.eps = eps
         self.residual_in_fp32 = residual_in_fp32
         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
         self.register_parameter('bias', None)
@@ -69,5 +85,5 @@ class DropoutAddRMSNorm(torch.nn.Module):
     def forward(self, x0, residual=None):
         return dropout_add_rms_norm(x0, residual, self.weight, None,
-                                    self.p if self.training else 0.0, self.epsilon,
+                                    self.p if self.training else 0.0, self.eps,
                                     prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
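
For reference, the new RMSNorm module defers to this file's fused rms_norm function; numerically it is roughly equivalent to the following plain-PyTorch sketch (not part of the diff; the fused kernel may differ in accumulation precision):

import torch

def rms_norm_ref(x, weight, eps=1e-5):
    # Normalize by the root mean square over the last dimension; no mean subtraction, no bias.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight

x = torch.randn(2, 4, 64)
weight = torch.ones(64)
y = rms_norm_ref(x, weight)
assert y.shape == x.shape
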
@@ -105,7 +105,7 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply

 class FusedDenseSqreluDense(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
+    def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True,
                  checkpoint_lvl=0, device=None, dtype=None):
         """
         checkpoint_lvl (increasing lvl means slower but more memory saving):
@@ -117,11 +117,12 @@ class FusedDenseSqreluDense(nn.Module):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        assert bias == True, "DenseSqreluDense module without bias is currently not supported"
+        hidden_features = hidden_features or in_features * 4
+        assert bias1 == True, "DenseSqreluDense module without bias is currently not supported"
+        assert bias2 == True, "DenseSqreluDense module without bias is currently not supported"
         self.checkpoint_lvl = checkpoint_lvl
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
         assert x.is_cuda
...
+# Copyright (c) 2023, Tri Dao.
 import time
 import torch
...
+# Copyright (c) 2023, Tri Dao.
 import time
 import torch
...
# Copyright (c) 2023, Tri Dao.

# To run the huggingface implementation, we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B

import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()

import torch
import pytest

from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache

@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_state_dict(model_name):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    rotary_inv_freq_keys = {f'transformer.layers.{l}.mixer.rotary_emb.inv_freq'
                            for l in range(config.n_layer)}
    assert state_dict.keys() == pretrained_state_dict.keys() | rotary_inv_freq_keys
    for k in state_dict.keys() - rotary_inv_freq_keys:
        assert state_dict[k].shape == pretrained_state_dict[k].shape

@pytest.mark.parametrize('model_name', ["7B", "13B"])
def test_llama_optimized(model_name):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict, strict=False)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    out_hf = model_hf.model(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()

# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.skip(reason="Tensor Parallel is not implemented for GatedMLP yet")
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
def test_llama_parallel(model_name, world_size):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank),
                          strict=False)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map="auto")
    model_hf.eval()
    out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
    logits_hf = model_hf(input_ids).logits.to(device=device)
    del model_hf

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()

@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_generation(model_name):
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    del model_hf

    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
    del model_ref

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict, strict=False)
    model.eval()

    print('Without CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         eos_token_id=eos_token_id, fused_ft_kernel=True,
                         return_dict_in_generate=True, output_scores=True, timing=True,
                         teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print('With CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                            fused_ft_kernel=True, cg=True,
                            return_dict_in_generate=True, output_scores=True, timing=True,
                            teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    # For some reason logits_parallel is off by quite a bit more than 2x
    assert (logits_parallel - logits_ref).abs().max().item() < 8 * hf_error

    print(f'HF fp16 logits max diff: {hf_error}')
    print(f'Logits max diff: {(logits - logits_parallel).abs().max().item()}')
    assert (logits - logits_parallel).abs().max().item() < 2 * hf_error
    print(f'Logits CG max diff: {(logits_cg - logits_parallel).abs().max().item()}')
    assert torch.equal(logits_cg, logits)