Commit b4018a50 authored by Tri Dao

Implement Tensor Parallel for GPT model

parent 78225c53
@@ -12,10 +12,16 @@ import torch.nn.functional as F
 from transformers import GPT2Config
-from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
+from flash_attn.modules.mha import MHA, ParallelMHA
+from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
-from flash_attn.modules.embedding import GPT2Embeddings
+from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
+from flash_attn.utils.distributed import sync_sequence_parallel_params
+try:
+    from flash_attn.ops.fused_dense import ColumnParallelLinear
+except ImportError:
+    ColumnParallelLinear = None
 try:
     from flash_attn.ops.layer_norm import dropout_add_layer_norm
@@ -28,32 +34,45 @@ except ImportError:
     FusedDenseSqreluDense = None

-def create_mixer_cls(config, layer_idx=None):
+def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
+    factory_kwargs = {'device': device, 'dtype': dtype}
     head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
     softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
     if config.scale_attn_by_inverse_layer_idx:
         assert layer_idx is not None
         softmax_scale /= float(layer_idx + 1)
     dwconv = getattr(config, 'attn_dwconv', False)
+    if dwconv:
+        assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
-                        softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
-                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
-                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
+    if not fused_bias_fc:
+        assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
+    mha_cls = MHA if process_group is None else ParallelMHA
+    serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
+                     if process_group is None else {})
+    parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
+                        softmax_scale=softmax_scale, causal=True,
+                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                        use_flash_attn=use_flash_attn,
+                        **serial_kwargs, **parallel_kwargs, **factory_kwargs)
     return mixer_cls

-def create_mlp_cls(config, layer_idx=None):
+def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
+    factory_kwargs = {'device': device, 'dtype': dtype}
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
+    if process_group is not None:
+        assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
     if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate='tanh'))
+                          activation=partial(F.gelu, approximate='tanh'), **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -63,24 +82,28 @@ def create_mlp_cls(config, layer_idx=None):
         if fused_dense_gelu_dense:
             if FusedDenseGeluDense is None:
                 raise ImportError('fused_dense is not installed')
-            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
-                              checkpoint_lvl=mlp_checkpoint_lvl)
+            mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
+            parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+            mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
+                              **parallel_kwargs, **factory_kwargs)
         elif fused_dense_sqrelu_dense:
             assert FusedDenseSqreluDense is not None
             mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
-                              checkpoint_lvl=mlp_checkpoint_lvl)
+                              checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
         else:
             raise RuntimeError('MLP type not supported')
     return mlp_cls

-def create_block(config, layer_idx=None):
-    mixer_cls = create_mixer_cls(config, layer_idx)
-    mlp_cls = create_mlp_cls(config, layer_idx)
-    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
+def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
+    factory_kwargs = {'device': device, 'dtype': dtype}
+    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
+    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
+    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=True, resid_dropout=config.resid_pdrop,
-                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
+                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
+                  sequence_parallel=process_group is not None)
     block.layer_idx = layer_idx
     return block
@@ -109,15 +132,23 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid

 class GPTModel(nn.Module):
-    def __init__(self, config: GPT2Config):
+    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         super().__init__()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        self.process_group = process_group
         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
-        self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
-                                         config.max_position_embeddings)
+        if process_group is None:
+            self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
+                                             config.max_position_embeddings, **factory_kwargs)
+        else:
+            self.embeddings = ParallelGPT2Embeddings(
+                config.hidden_size, config.vocab_size, config.max_position_embeddings,
+                process_group=process_group, **factory_kwargs
+            )
         self.emb_drop = nn.Dropout(config.embd_pdrop)
         # We change the order of residual and layer norm:
@@ -131,16 +162,29 @@ class GPTModel(nn.Module):
                 raise ImportError('dropout_add_layer_norm is not installed')
         # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
         # is the final layer norm.
-        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
+        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
+                                 **factory_kwargs)
+        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
+        if process_group is not None:
+            for p in self.ln_0.parameters():
+                p._sequence_parallel = True
+        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
+                                                  **factory_kwargs)
                                      for i in range(config.num_hidden_layers)])
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None):
-        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
+        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
+        # dimensions so that we can split on it easily, in case of small batch size.
+        # Only the attention layers need to know the seqlen.
+        embedding_kwargs = ({'combine_batch_seqlen_dim': True}
+                            if self.process_group is not None else {})
+        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         if not self.fused_dropout_add_ln:
             residual = self.emb_drop(hidden_states).float()
@@ -151,21 +195,32 @@ class GPTModel(nn.Module):
                 self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                 residual_in_fp32=True
             )
+        mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
         for layer in self.layers:
-            hidden_states, residual = layer(hidden_states, residual)
+            hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
         return hidden_states

 class GPTLMHeadModel(nn.Module):
-    def __init__(self, config: GPT2Config):
+    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.transformer = GPTModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.process_group = process_group
+        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
+        if process_group is None:
+            self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, **factory_kwargs)
+        else:
+            if ColumnParallelLinear is None:
+                raise ImportError('fused_dense_lib is not installed')
+            self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group,
                                                bias=False, **factory_kwargs)
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)

     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
...
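As a usage reference, here is a minimal sketch of constructing the tensor-parallel model. It assumes a torchrun-style launch and the apex parallel_state helpers used by the new test below; the specific config values are illustrative only and not part of this commit.

# Minimal sketch (assumptions: torchrun launch, apex installed, NCCL-capable GPUs).
import torch
from apex.transformer import parallel_state
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
world_size = torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
process_group = parallel_state.get_tensor_model_parallel_group()

# Tensor Parallel requires the fused code paths (see the asserts in create_mixer_cls / create_mlp_cls).
config = GPT2Config(n_embd=1024, n_head=16, n_layer=2, vocab_size=50257,
                    use_flash_attn=True, fused_bias_fc=True, fused_dense_gelu_dense=True,
                    fused_dropout_add_ln=True, pad_vocab_size_multiple=8 * world_size)
model = GPTLMHeadModel(config, process_group=process_group, device=device)

input_ids = torch.randint(0, config.vocab_size, (2, 512), device=device)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    logits = model(input_ids).logits  # each rank holds a shard of the vocab dimension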
@@ -72,3 +72,26 @@ class ReduceScatterFunc(torch.autograd.Function):
 # Supports autograd, but does not support async
 reduce_scatter = ReduceScatterFunc.apply

+def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {name: p for name, p in model.named_parameters()
+                          if getattr(p, '_sequence_parallel', False)}
+    for _, p in sorted(params_seqparallel.items()):
+        with torch.no_grad():
+            # Broadcast needs src to be global rank, not group rank
+            torch.distributed.broadcast(
+                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+            )
+
+def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {name: p for name, p in model.named_parameters()
+                          if getattr(p, '_sequence_parallel', False)}
+    for _, p in sorted(params_seqparallel.items()):
+        with torch.no_grad():
+            torch.distributed.all_reduce(p.grad, group=process_group)
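For context: sync_sequence_parallel_params runs once after construction (the GPT constructors above already call it), while allreduce_sequence_parallel_grad runs after backward and before the optimizer update. Sorting by parameter name keeps the collectives in the same order on every rank even when some ranks own extra parameters (the rank-0-only biases). A minimal sketch of a training step, assuming a loss_fn that reduces to a scalar; the wrapper function itself is hypothetical and not part of this commit.

import torch
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

def training_step(model, optimizer, loss_fn, input_ids, labels, process_group):
    # Hypothetical step showing the ordering; only the all-reduce helper comes from this commit.
    logits = model(input_ids).logits
    loss = loss_fn(logits, labels.flatten())
    loss.backward()
    # LayerNorm params tagged with `_sequence_parallel` are replicated on every rank but see
    # different sequence shards, so their grads must be summed over the tensor-parallel group.
    allreduce_sequence_parallel_grad(model, process_group)
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()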
@@ -46,8 +46,9 @@ def test_cross_entropy_loss_apex(vocab_size, world_size, smoothing, inplace_back
     y[torch.randperm(batch_size * seqlen)[:10]] = -100
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction='none')
     model = CrossEntropyLoss(label_smoothing=smoothing, reduction='none',
-                             inplace_backward=inplace_backward)
-    out = model(x, y, process_group=parallel_state.get_tensor_model_parallel_group())
+                             inplace_backward=inplace_backward,
+                             process_group=parallel_state.get_tensor_model_parallel_group())
+    out = model(x, y)
     out_pt = model_pt(x_pt.float(), y)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
...
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest
from einops import rearrange
from transformers import GPT2Config
from apex.transformer import parallel_state
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [1])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, has_pos_emb, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
assert num_heads % world_size == 0
vocab_size = 50264
assert vocab_size % world_size == 0
num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(n_embd=dim, n_head=num_heads, n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size)
model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module):
if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight)
nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device)
total_nparams = sum(p.numel() for p in model_pt.parameters())
sharded_nparams = sum(p.numel() for p in model.parameters())
sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
sequence_parallel_nparams = sum(p.numel() for p in model.parameters()
if getattr(p, '_sequence_parallel', False))
sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device),
group=process_group
)
assert torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams)
assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item()
+ sequence_parallel_nparams)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
with torch.no_grad():
model.transformer.embeddings.word_embeddings.weight.copy_(
model_pt.transformer.embeddings.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
)
if has_pos_emb:
model.transformer.embeddings.position_embeddings.weight.copy_(
model_pt.transformer.embeddings.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
)
model.transformer.ln_0.weight.copy_(model_pt.transformer.ln_0.weight)
model.transformer.ln_0.bias.copy_(model_pt.transformer.ln_0.bias)
for i in range(num_layers):
model.transformer.layers[i].mixer.Wqkv.weight.copy_(
rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o i -> (three o) i')
)
model.transformer.layers[i].mixer.Wqkv.bias.copy_(
rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)')
)
model.transformer.layers[i].mixer.out_proj.weight.copy_(
model_pt.transformer.layers[i].mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
)
if rank == 0:
model.transformer.layers[i].mixer.out_proj.bias.copy_(model_pt.transformer.layers[i].mixer.out_proj.bias)
model.transformer.layers[i].mlp.fc1.weight.copy_(
model_pt.transformer.layers[i].mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
)
model.transformer.layers[i].mlp.fc1.bias.copy_(
model_pt.transformer.layers[i].mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
)
model.transformer.layers[i].mlp.fc2.weight.copy_(
model_pt.transformer.layers[i].mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
)
if rank == 0:
model.transformer.layers[i].mlp.fc2.bias.copy_(model_pt.transformer.layers[i].mlp.fc2.bias)
model.transformer.layers[i].norm1.weight.copy_(model_pt.transformer.layers[i].norm1.weight)
model.transformer.layers[i].norm1.bias.copy_(model_pt.transformer.layers[i].norm1.bias)
model.transformer.layers[i].norm2.weight.copy_(model_pt.transformer.layers[i].norm2.weight)
model.transformer.layers[i].norm2.bias.copy_(model_pt.transformer.layers[i].norm2.bias)
# Don't need to copy the lm_head weight since it's tied to the word embedding weight
with torch.autocast(device_type='cuda', dtype=dtype):
out = model(input_ids[:, :-1]).logits
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out, out_pt[:, rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol
)
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction='none', process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction='none')
loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
loss_pt.backward(g)
loss.backward(g)
allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel()
assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad,
model_pt.transformer.embeddings.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
rtol=rtol, atol=atol * 5
)
if has_pos_emb:
assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad,
model_pt.transformer.embeddings.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol
)
assert torch.allclose(model.transformer.ln_0.weight.grad, model_pt.transformer.ln_0.weight.grad,
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_0.bias.grad, model_pt.transformer.ln_0.bias.grad,
rtol=rtol, atol=atol)
for i in range(num_layers):
# if rank == 0: breakpoint()
# torch.distributed.barrier()
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad,
rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim], 'three o i -> (three o) i'),
rtol=rtol, atol=atol * 10
)
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
'three o -> (three o)'),
rtol=rtol, atol=atol * 10
)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad,
model_pt.transformer.layers[i].mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
rtol=rtol, atol=atol * 10
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad, model_pt.transformer.layers[i].mixer.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad,
model_pt.transformer.layers[i].mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad,
model_pt.transformer.layers[i].mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad,
model_pt.transformer.layers[i].mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
rtol=rtol, atol=atol * 10
)
if rank == 0:
assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad, model_pt.transformer.layers[i].mlp.fc2.bias.grad,
rtol=rtol, atol=atol * 5)
assert torch.allclose(model.transformer.layers[i].norm1.weight.grad, model_pt.transformer.layers[i].norm1.weight.grad, rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm1.bias.grad, model_pt.transformer.layers[i].norm1.bias.grad, rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.weight.grad, model_pt.transformer.layers[i].norm2.weight.grad, rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.layers[i].norm2.bias.grad, model_pt.transformer.layers[i].norm2.bias.grad, rtol=rtol, atol=atol)
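The weight-copy block in this test doubles as documentation of the sharding layout: column-parallel layers (fc1, each q/k/v chunk of Wqkv, lm_head, the vocab dimension of the word embedding) shard the output dimension, while row-parallel layers (out_proj, fc2) shard the input dimension and keep their bias only on rank 0. Hypothetical helpers restating just that slicing (not part of the commit):

import torch

def column_parallel_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Column-parallel: split the output dimension, i.e. the rows of the weight.
    out_per_rank = weight.shape[0] // world_size
    return weight[rank * out_per_rank:(rank + 1) * out_per_rank]

def row_parallel_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Row-parallel: split the input dimension, i.e. the columns of the weight;
    # the bias lives only on rank 0 (hence the `if rank == 0` copies above).
    in_per_rank = weight.shape[1] // world_size
    return weight[:, rank * in_per_rank:(rank + 1) * in_per_rank]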
@@ -17,6 +17,7 @@ from apex.transformer import tensor_parallel
 from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
+from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@@ -124,13 +125,7 @@ def test_block_parallel(dim, world_size, dtype):
     out_pt.backward(g)
     out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
-    # We want to iterate over parameters with _sequence_parallel=True in the same order,
-    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
-    params_seqparallel = {name: p for name, p in model.named_parameters()
-                          if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
-        if getattr(p, '_sequence_parallel', False):
-            torch.distributed.all_reduce(p.grad, group=parallel_state.get_tensor_model_parallel_group())
+    allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
...
@@ -106,7 +106,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
 class CrossEntropyLoss(nn.Module):
     def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False):
+                 inplace_backward=False, process_group=None):
         super().__init__()
         if reduction not in ['mean', 'none']:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
@@ -114,13 +114,14 @@ class CrossEntropyLoss(nn.Module):
         self.reduction = reduction
         self.label_smoothing = label_smoothing
         self.inplace_backward = inplace_backward
+        self.process_group = process_group

-    def forward(self, input, target, process_group=None):
+    def forward(self, input, target):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
             input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            process_group
+            self.process_group
         )
         if self.reduction == 'mean':
             return loss.sum() / (target != self.ignore_index).sum()
...
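With this change the process group is fixed at construction time instead of being passed to every forward call. A small sketch of the updated usage, mirroring the test change above; the wrapper function is hypothetical.

from apex.transformer import parallel_state
from flash_attn.losses.cross_entropy import CrossEntropyLoss

def parallel_lm_loss(logits, labels):
    # `logits` is this rank's vocab shard of shape (tokens, vocab // world_size); `labels` is (tokens,).
    loss_fn = CrossEntropyLoss(reduction='none', inplace_backward=True,
                               process_group=parallel_state.get_tensor_model_parallel_group())
    return loss_fn(logits, labels)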