Commit 4b661a56 authored by Tri Dao

[GPT] Run black on gpt.py

parent bec5b3d3
@@ -3,32 +3,34 @@
import logging
import math
import re
from collections import OrderedDict, namedtuple
from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (
    FusedMLP,
    GatedMlp,
    Mlp,
    ParallelFusedMLP,
    ParallelGatedMlp,
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -65,158 +67,247 @@ logger = logging.getLogger(__name__)
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, "attn_dwconv", False)
    if dwconv:
        assert process_group is None, "TensorParallel MHA does not support dwconv yet"
    qkv_proj_bias = getattr(config, "qkv_proj_bias", True)
    out_proj_bias = getattr(config, "out_proj_bias", True)
    rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim)
    rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0)
    rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None)
    rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False)
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    if not fused_bias_fc:
        assert process_group is None, "TensorParallel MHA requires fused_bias_fc"
    mha_cls = MHA if process_group is None else ParallelMHA
    serial_kwargs = (
        {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {}
    )
    parallel_kwargs = (
        {
            "process_group": process_group,
            "sequence_parallel": getattr(config, "sequence_parallel", True),
        }
        if process_group is not None
        else {}
    )
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(
        mha_cls,
        num_heads=config.num_attention_heads,
        num_heads_kv=num_heads_kv,
        qkv_proj_bias=qkv_proj_bias,
        out_proj_bias=out_proj_bias,
        dropout=config.attn_pdrop,
        softmax_scale=softmax_scale,
        causal=True,
        layer_idx=layer_idx,
        rotary_emb_dim=rotary_emb_dim,
        rotary_emb_base=rotary_emb_base,
        rotary_emb_scale_base=rotary_emb_scale_base,
        rotary_emb_interleaved=rotary_emb_interleaved,
        use_flash_attn=use_flash_attn,
        **serial_kwargs,
        **parallel_kwargs,
        **factory_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True)
    mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True)
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.activation_function in [
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
        ]
    fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False)
    if fused_dense_sqrelu_dense:
        assert config.activation_function == "sqrelu", (
            "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu"
        )
    assert not (fused_dense_sqrelu_dense and fused_mlp)
    if not fused_mlp and not fused_dense_sqrelu_dense:
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            activation = (
                F.sigmoid
                if config.activation_function == "glu"
                else (F.silu if config.activation_function == "swiglu" else F.gelu)
            )
            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        else:
            if config.activation_function == "relu":
                activation = partial(F.relu, inplace=True)
            elif config.activation_function == "sqrelu":
                activation = sqrelu_fwd
            else:
                approximate = (
                    "tanh"
                    if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
                    else "none"
                )
                activation = partial(F.gelu, approximate=approximate)
            mlp_cls = Mlp if process_group is None else ParallelMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
    else:
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_mlp:
            if FusedMLP is None:
                raise ImportError("fused_dense is not installed")
            activation = (
                "gelu_approx"
                if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
                else config.activation_function
            )
            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
            parallel_kwargs = (
                {
                    "process_group": process_group,
                    "sequence_parallel": getattr(config, "sequence_parallel", True),
                }
                if process_group is not None
                else {}
            )
            mlp_cls = partial(
                mlp_cls,
                hidden_features=config.n_inner,
                activation=activation,
                checkpoint_lvl=mlp_checkpoint_lvl,
                bias1=mlp_fc1_bias,
                bias2=mlp_fc2_bias,
                **parallel_kwargs,
                **factory_kwargs,
            )
        elif fused_dense_sqrelu_dense:
            if process_group is not None:
                assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense"
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(
                FusedDenseSqreluDense,
                hidden_features=config.n_inner,
                checkpoint_lvl=mlp_checkpoint_lvl,
                **factory_kwargs,
            )
        else:
            raise RuntimeError("MLP type not supported")
    return mlp_cls


def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
    factory_kwargs = {"device": device, "dtype": dtype}
    sequence_parallel = getattr(config, "sequence_parallel", True)
    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
    use_rms_norm = getattr(config, "rms_norm", False)
    norm_cls = partial(
        nn.LayerNorm if not use_rms_norm else RMSNorm,
        eps=config.layer_norm_epsilon,
        **factory_kwargs,
    )
    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
    residual_in_fp32 = getattr(config, "residual_in_fp32", False)
    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
    prenorm = getattr(config, "prenorm", True)
    parallel_block = getattr(config, "parallel_block", False)
    if not parallel_block:
        block = Block(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            prenorm=prenorm,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    else:
        assert prenorm
        block = ParallelBlock(
            config.hidden_size,
            mixer_cls,
            mlp_cls,
            norm_cls=norm_cls,
            resid_dropout1=resid_dropout1,
            resid_dropout2=config.resid_pdrop,
            tied_norm=getattr(config, "parallel_block_tied_norm", False),
            fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
            residual_in_fp32=residual_in_fp32,
            sequence_parallel=sequence_parallel and process_group is not None,
            mark_shared_params=process_group is not None,
        )
    block.layer_idx = layer_idx
    return block
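# Illustrative sketch (not part of this commit): create_block above is driven entirely by
# GPT2Config attributes; the flash-attn specific switches are plain extra fields read with
# getattr. The exact values below are assumptions for a small, non-tensor-parallel model.
#
#   import torch
#   from transformers import GPT2Config
#
#   cfg = GPT2Config(n_embd=768, n_head=12, n_layer=12)
#   cfg.use_flash_attn = True   # assumes the FlashAttention kernels are installed
#   cfg.fused_bias_fc = False   # acceptable here because process_group is None
#   block = create_block(cfg, layer_idx=0, device="cuda", dtype=torch.float16)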
class GPTPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for dowloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, GPT2Config):
@@ -225,12 +316,23 @@ class GPTPreTrainedModel(nn.Module):
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(
        cls,
        model_name,
        config,
        *args,
        strict=True,
        device=None,
        dtype=None,
        world_size=1,
        rank=0,
        **kwargs,
    ):
        """
        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
@@ -239,21 +341,19 @@ class GPTPreTrainedModel(nn.Module):
        model = cls(config, *args, device=device, dtype=dtype, **kwargs)
        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
        # want extra stuff taking up more GPU memory
        state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype)
        if model_name.startswith("gpt2"):
            state_dict = remap_state_dict_hf_gpt2(state_dict, config)
        elif model_name.startswith("facebook/opt"):
            state_dict = remap_state_dict_hf_opt(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-j-"):
            state_dict = remap_state_dict_hf_gptj(state_dict, config)
        elif model_name.startswith("EleutherAI/gpt-neox-"):
            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
        elif model_name.startswith("tiiuae/falcon-"):
            state_dict = remap_state_dict_hf_falcon(state_dict, config)
        else:
            raise NotImplementedError(f"Model {model_name} not supported")
        if world_size > 1:
            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
        load_return = model.load_state_dict(state_dict, strict=strict)
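        # Hypothetical usage sketch (not part of this commit): loading Hugging Face GPT-2 weights
        # through the remap_state_dict_hf_gpt2 path above. The extra config flags are optional
        # flash-attn features and are shown purely as assumptions.
        #
        #   import torch
        #   from transformers import GPT2Config
        #
        #   config = GPT2Config.from_pretrained("gpt2")
        #   config.use_flash_attn = True  # assumes the CUDA kernels are available
        #   model = GPTLMHeadModel.from_pretrained(
        #       "gpt2", config, device="cuda", dtype=torch.float16
        #   )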
@@ -284,36 +384,51 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)
        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )
        # We change the order of dropout, residual and layer norm:
@@ -322,20 +437,25 @@ class GPTModel(GPTPreTrainedModel):
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if (not self.parallel_block and dropout_add_layer_norm is None) or (
                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
            ):
                raise ImportError("dropout_layer_norm is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
            if process_group is not None:
                for p in self.ln_f.parameters():
                    # Mark the norm parameters as "shared_params" so that we sync their values at init.
@@ -344,8 +464,13 @@ class GPTModel(GPTPreTrainedModel):
                    if self.sequence_parallel:
                        p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
@@ -353,28 +478,37 @@ class GPTModel(GPTPreTrainedModel):
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
@@ -388,45 +522,66 @@ class GPTModel(GPTPreTrainedModel):
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                if not self.parallel_block:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm
                    )
                    hidden_states = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                else:
                    fused_add_norm_fn = (
                        dropout_add_rms_norm_parallel_residual
                        if isinstance(self.ln_f, RMSNorm)
                        else dropout_add_layer_norm_parallel_residual
                    )
                    hidden_states, _ = fused_add_norm_fn(
                        hidden_states,
                        hidden_states2,
                        residual,
                        self.ln_f.weight,
                        self.ln_f.bias,
                        None,
                        None,
                        self.drop_f.p if self.training else 0.0,
                        self.ln_f.eps,
                        prenorm=False,
                        residual_in_fp32=self.residual_in_fp32,
                    )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
@@ -436,14 +591,23 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
            )
        )
        self.tie_weights()

    def tie_weights(self):
@@ -453,18 +617,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
        """
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        last_token_only: whether to return the logit for the last token only,
            of shape (batch_size, vocab_size)
        """
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if last_token_only:
            hidden_states = hidden_states[:, -1]
        if self.project_out is not None:
@@ -473,34 +639,34 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)
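    # Hypothetical sketch (not part of this commit): forward() above returns a namedtuple with a
    # single `logits` field, so a minimal greedy next-token step could look like the lines below
    # (reusing the illustrative `model` from the earlier from_pretrained sketch; the prompt ids
    # and device are assumptions).
    #
    #   input_ids = torch.tensor([[50256]], device="cuda")
    #   out = model(input_ids, last_token_only=True)        # out.logits: (batch_size, vocab_size)
    #   next_token = out.logits.argmax(dim=-1, keepdim=True)
    #   input_ids = torch.cat([input_ids, next_token], dim=1)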

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)
@@ -508,8 +674,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
@@ -519,64 +685,84 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[-1] // world_size
            state_dict[key] = x[..., rank * dim : (rank + 1) * dim]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head = config.n_head
            n_head_kv = getattr(config, "n_head_kv", n_head)
            assert n_head % world_size == 0 and n_head_kv % world_size == 0
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                dim = x.shape[1] // world_size
                state_dict[key] = rearrange(
                    x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
                )
            else:
                n_head_per_rank = n_head // world_size
                n_head_kv_per_rank = n_head_kv // world_size
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank],
                            x[
                                n_head
                                + rank * n_head_kv_per_rank : n_head
                                + (rank + 1) * n_head_kv_per_rank
                            ],
                            x[
                                n_head
                                + n_head_kv
                                + rank * n_head_kv_per_rank : n_head
                                + n_head_kv
                                + (rank + 1) * n_head_kv_per_rank
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict
@@ -586,8 +772,8 @@ def combine_state_dicts_tp(state_dicts, config):
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
@@ -605,90 +791,125 @@ def combine_state_dicts_tp(state_dicts, config):
    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        n_head_per_rank = n_head // world_size
        n_head_kv_per_rank = n_head_kv // world_size
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=n_head + 2 * n_head_kv,
                    )
                    for s in state_dicts
                ]
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                            torch.cat(
                                [
                                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
                                    for x in xs
                                ],
                                dim=0,
                            ),
                            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0)
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict
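# Hypothetical round-trip sketch (not part of this commit): shard_state_dict_tp above splits a
# full state_dict into one shard per tensor-parallel rank, and combine_state_dicts_tp merges the
# shards back; `model`, `config`, and world_size=2 below are illustrative assumptions.
#
#   full_sd = model.state_dict()
#   shards = [shard_state_dict_tp(dict(full_sd), config, world_size=2, rank=r) for r in range(2)]
#   merged = combine_state_dicts_tp(shards, config)
#   assert all(merged[k].shape == full_sd[k].shape for k in full_sd)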
def remap_state_dict_hf_gpt2(state_dict, config): def remap_state_dict_hf_gpt2(state_dict, config):
# Word embedding and position embedding # Word embedding and position embedding
def key_mapping_pos_emb(key): def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
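For reference, here is a minimal standalone sketch (toy shapes and variable names of my own choosing, not part of the repo) of the two shape manipulations the GPT-2 remapping above relies on: Hugging Face's Conv1D layers keep weights as (in_features, out_features), so they are transposed into the (out_features, in_features) layout of nn.Linear, and the word-embedding matrix is zero-padded until the vocabulary size is a multiple of pad_vocab_size_multiple.

import math

import torch
import torch.nn.functional as F

hidden, vocab, pad_multiple = 8, 11, 8

# HF GPT-2 stores c_fc / c_proj / c_attn as Conv1D weights of shape
# (in_features, out_features); nn.Linear expects (out_features, in_features),
# hence the .t() calls in the remapping above.
conv1d_weight = torch.randn(hidden, 4 * hidden)  # (in, out), Conv1D convention
linear_weight = conv1d_weight.t().contiguous()   # (out, in), nn.Linear convention
assert linear_weight.shape == (4 * hidden, hidden)

# Zero-pad the embedding rows so the padded vocab size is a multiple of 8.
word_embeddings = torch.randn(vocab, hidden)
padded_vocab = math.ceil(vocab / pad_multiple) * pad_multiple
padded = F.pad(word_embeddings, (0, 0, 0, padded_vocab - vocab))
assert padded.shape == (padded_vocab, hidden) and padded_vocab == 16

Tying lm_head.weight to the padded embedding matrix afterwards keeps the output projection consistent with the padded vocabulary.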
@@ -696,66 +917,94 @@ def remap_state_dict_hf_gpt2(state_dict, config):
def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )
    return state_dict
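As a quick sanity check on the Wqkv/bqkv reordering at the end of remap_state_dict_megatron, the toy snippet below (made-up head count and sizes, not library code) applies the same einops pattern to a small matrix and verifies that a query row of head 1 moves from Megatron's head-interleaved position to the head-major position expected here.

import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 6

# Megatron packs rows as (nheads, three, headdim); this repo expects
# (three, nheads, headdim) along the fused output dimension.
wqkv_megatron = torch.arange(nheads * 3 * headdim * hidden, dtype=torch.float32).reshape(
    nheads * 3 * headdim, hidden
)
wqkv = rearrange(
    wqkv_megatron,
    "(nheads three headdim) ... -> (three nheads headdim) ...",
    three=3,
    headdim=headdim,
)
# Q row (head=1, dim=0): index 1 * headdim in the new layout,
# index 1 * (3 * headdim) in the Megatron layout.
assert torch.equal(wqkv[1 * headdim], wqkv_megatron[1 * 3 * headdim])

The bias gets the same treatment minus the trailing hidden dimension, which is why the bqkv pattern drops the "...".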