Commit 93383bd5 authored by Tri Dao

[TP] Implement TensorParallel without sequence parallel

parent ce26d3d7
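Before the diff, it may help to sketch what "tensor parallel without sequence parallel" changes in the communication pattern. The following is a conceptual sketch in plain torch.distributed (the function name, shapes, and the use of GELU here are illustrative only; the real implementation lives in the fused kernels changed below):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def tp_mlp_shard(x, w1_shard, w2_shard, group, sequence_parallel=True):
    # Conceptual data flow of one tensor-parallel MLP shard (column- then row-parallel).
    # With sequence_parallel=True, x is this rank's shard of the (batch*seqlen) tokens;
    # with sequence_parallel=False, x is the full input replicated on every rank.
    world_size = dist.get_world_size(group)
    if sequence_parallel:
        # All-gather the token shards from all ranks before the first matmul.
        total_x = torch.empty(world_size * x.shape[0], x.shape[1],
                              dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(total_x, x.contiguous(), group=group)
    else:
        # No gather: the input is already fully replicated.
        total_x = x
    # Each rank holds a slice of the weights, so this produces a partial sum.
    partial = F.gelu(total_x @ w1_shard.t(), approximate='tanh') @ w2_shard.t()
    if sequence_parallel:
        # Reduce-scatter: sum the partial outputs and keep only this rank's token shard.
        out = torch.empty(partial.shape[0] // world_size, partial.shape[1],
                          dtype=partial.dtype, device=partial.device)
        dist.reduce_scatter_tensor(out, partial, group=group)
    else:
        # All-reduce: sum the partial outputs; every rank keeps the full result.
        dist.all_reduce(partial, group=group)
        out = partial
    return out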
@@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
-from flash_attn.utils.distributed import sync_sequence_parallel_params
+from flash_attn.utils.distributed import sync_shared_params
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin
@@ -62,7 +62,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     mha_cls = MHA if process_group is None else ParallelMHA
     serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
                      if process_group is None else {})
-    parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+    parallel_kwargs = ({'process_group': process_group,
+                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                       if process_group is not None else {})
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
                         rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
@@ -99,7 +101,9 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedDenseGeluDense is None:
             raise ImportError('fused_dense is not installed')
         mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
-        parallel_kwargs = {'process_group': process_group} if process_group is not None else {}
+        parallel_kwargs = ({'process_group': process_group,
+                            'sequence_parallel': getattr(config, 'sequence_parallel', True)}
+                           if process_group is not None else {})
         mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
                           **parallel_kwargs, **factory_kwargs)
     elif fused_dense_sqrelu_dense:
@@ -113,13 +117,15 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
 def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
+    sequence_parallel = getattr(config, 'sequence_parallel', True)
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=True, resid_dropout=config.resid_pdrop,
                   fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
-                  sequence_parallel=process_group is not None)
+                  sequence_parallel=sequence_parallel and process_group is not None,
+                  mark_shared_params=process_group is not None)
     block.layer_idx = layer_idx
     return block
@@ -180,6 +186,7 @@ class GPTModel(GPTPreTrainedModel):
         super().__init__(config)
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
+        self.sequence_parallel = getattr(config, 'sequence_parallel', True)
         assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:
@@ -192,7 +199,8 @@ class GPTModel(GPTPreTrainedModel):
         else:
             self.embeddings = ParallelGPT2Embeddings(
                 config.hidden_size, config.vocab_size, config.max_position_embeddings,
-                process_group=process_group, **factory_kwargs
+                process_group=process_group, sequence_parallel=self.sequence_parallel,
+                **factory_kwargs
             )
         self.emb_drop = nn.Dropout(config.embd_pdrop)
@@ -209,10 +217,13 @@ class GPTModel(GPTPreTrainedModel):
         # is the final layer norm.
         self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
                                  **factory_kwargs)
-        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if process_group is not None:
             for p in self.ln_0.parameters():
-                p._sequence_parallel = True
+                # Mark the norm parameters as "shared_params" so that we sync their values at init.
+                p._shared_params = True
+                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
+                if self.sequence_parallel:
+                    p._sequence_parallel = True
         self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
                                                   **factory_kwargs)
@@ -224,14 +235,14 @@
     def tie_weights(self):
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.
         # Only the attention layers need to know the seqlen.
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
-                            if self.process_group is not None else {})
+                            if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         if not self.fused_dropout_add_ln:
@@ -243,7 +254,8 @@ class GPTModel(GPTPreTrainedModel):
                 self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                 residual_in_fp32=True
             )
-        mixer_kwargs = ({'seqlen': input_ids.shape[1]} if self.process_group is not None else {})
+        mixer_kwargs = ({'seqlen': input_ids.shape[1]}
+                        if self.process_group is not None and self.sequence_parallel else {})
         if inference_params is not None:
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
@@ -263,8 +275,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
-            self.lm_head = ColumnParallelLinear(config.n_embd, config.vocab_size, process_group,
-                                                bias=False, **factory_kwargs)
+            self.lm_head = ColumnParallelLinear(
+                config.n_embd, config.vocab_size, process_group, bias=False,
+                sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
+            )
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
@@ -273,7 +287,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
         if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
+            sync_shared_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
...
@@ -23,7 +23,8 @@ class Block(nn.Module):
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False):
+                 fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
+                 mark_shared_params=False):
         """
         return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
             This is for performance reason: for post-norm architecture, returning the input allows us
@@ -51,6 +52,12 @@ class Block(nn.Module):
             assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
             assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
+        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
+        # then the input to each worker in the tensor parallel group will be different.
+        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
+        # For now this is not an issue because we always use sequence_parallel=True during training
+        # and only use sequence_parallel=False during inference.
+
         # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
         if sequence_parallel:
             for p in self.norm1.parameters():
@@ -58,6 +65,13 @@ class Block(nn.Module):
             if hasattr(self, 'norm2'):
                 for p in self.norm2.parameters():
                     p._sequence_parallel = True
+        # Mark the norm parameters as "shared_params" so that we sync their values at init.
+        if mark_shared_params:
+            for p in self.norm1.parameters():
+                p._shared_params = True
+            if hasattr(self, 'norm2'):
+                for p in self.norm2.parameters():
+                    p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_kwargs=None):
...
@@ -6,7 +6,7 @@ from torch import Tensor
 from einops import rearrange

-from flash_attn.utils.distributed import reduce_scatter
+from flash_attn.utils.distributed import reduce_scatter, all_reduce

 class GPT2Embeddings(nn.Module):
@@ -130,13 +130,14 @@ class ColumnParallelEmbedding(nn.Embedding):
 class ParallelGPT2Embeddings(nn.Module):
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
-                 padding_idx=None, device=None, dtype=None):
+                 padding_idx=None, sequence_parallel=True, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
         self.word_embeddings = VocabParallelEmbedding(
             vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
             **factory_kwargs
@@ -167,4 +168,5 @@ class ParallelGPT2Embeddings(nn.Module):
                 embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
         if combine_batch_seqlen_dim:
             embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
@@ -497,11 +497,10 @@ class ParallelMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
-                 rotary_emb_scale_base=0,
-                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_scale_base=0, use_flash_attn=False, checkpointing=False,
+                 sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
         self.layer_idx = layer_idx
@@ -521,12 +520,13 @@ class ParallelMHA(nn.Module):
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
         self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
-                                         **factory_kwargs)
+                                         sequence_parallel=sequence_parallel, **factory_kwargs)
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                          attention_dropout=dropout)
         # output projection always have the bias (for now)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
+                                          sequence_parallel=sequence_parallel, **factory_kwargs)

     def forward(self, x, seqlen=None, **kwargs):
         """
...
@@ -15,26 +15,29 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda
 from flash_attn.ops.gelu_activation import gelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
+from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import reduce_scatter, all_reduce

 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod
     @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
+    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
+                sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather_raw of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
         """
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:
@@ -44,7 +47,7 @@ class FusedDenseFunc(torch.autograd.Function):
             weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
             bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
@@ -66,9 +69,10 @@ class FusedDenseFunc(torch.autograd.Function):
             grad_input, = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         if ctx.compute_weight_gradient:
             x, weight = ctx.saved_tensors
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             else:
                 total_x = x
@@ -86,13 +90,13 @@ class FusedDenseFunc(torch.autograd.Function):
                                          grad_output, weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.needs_input_grad[1]:
             assert ctx.compute_weight_gradient
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 handle_x.wait()
             grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                 total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
@@ -102,15 +106,17 @@ class FusedDenseFunc(torch.autograd.Function):
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None

 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
-                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
+                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
+                     sequence_parallel: bool = True):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
+        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
+                                    sequence_parallel)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
@@ -136,7 +142,7 @@ class FusedDense(nn.Linear):
 class ColumnParallelLinear(nn.Linear):

     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         if out_features % world_size != 0:
             raise ValueError(f'out_features ({out_features}) must be divisible by '
@@ -144,19 +150,20 @@ class ColumnParallelLinear(nn.Linear):
         super().__init__(in_features, out_features // world_size, bias=bias,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel

     def forward(self, x):
-        """
-        We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
-        x before doing the matmul.
-        """
-        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
+        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
+        # we do an all_gather of x before doing the matmul.
+        # If not, then the input is already gathered.
+        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
+                                sequence_parallel=self.sequence_parallel)

 class RowParallelLinear(nn.Linear):
     def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
-                 bias: bool = True, device=None, dtype=None) -> None:
+                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
         world_size = torch.distributed.get_world_size(process_group)
         rank = torch.distributed.get_rank(process_group)
         if in_features % world_size != 0:
@@ -166,6 +173,7 @@ class RowParallelLinear(nn.Linear):
         super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                          device=device, dtype=dtype)
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel

     def forward(self, x):
         """
@@ -173,7 +181,8 @@ class RowParallelLinear(nn.Linear):
         a reduce_scatter of the result.
         """
         out = fused_dense_func(x, self.weight, self.bias)
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
 class FusedDenseGeluDenseFunc(torch.autograd.Function):
@@ -181,10 +190,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
-                checkpoint_lvl=0, heuristic=0, process_group=None):
+                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
         """
-        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
-        we do an all_gather of x before doing the matmul.
+        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
+        with sequence parallelism: we do an all_gather of x before doing the matmul.
+        If sequence_parallel=False, then the input is already gathered.

         checkpoint_lvl:
         0: no recomputation in the bwd
@@ -197,13 +207,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.sequence_parallel = sequence_parallel
         ctx.checkpoint_lvl = checkpoint_lvl
         ctx.heuristic = heuristic
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             # We want to kick off the all_gather early, before weight dtype conversion
             total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
         else:
@@ -218,7 +229,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
-        if process_group is not None:
+        if process_group is not None and sequence_parallel:
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
@@ -257,13 +268,14 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             grad_input, = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
+        sequence_parallel = ctx.sequence_parallel
         x, weight1, weight2, *rest = ctx.saved_tensors
-        if process_group is None:
+        if process_group is None or not sequence_parallel:
             total_x = x
         batch_shape = grad_output.shape[:-1]
         batch_dim = batch_shape.numel()
         if checkpoint_lvl in [0, 1]:
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
             if checkpoint_lvl == 0:
                 gelu_in, output1 = rest
@@ -272,7 +284,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 output1 = F.gelu(gelu_in, approximate='tanh')
         elif checkpoint_lvl == 2:
             bias1, = rest
-            if process_group is not None:
+            if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 gelu_in = F.linear(total_x, weight1, bias1)
@@ -314,13 +326,13 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                                          grad_gelu, weight1)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             if process_group is not None:
-                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
-                                                                   async_op=True)
+                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
+                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
         else:
             grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
@@ -331,7 +343,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
                 grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None:
+                if process_group is not None and sequence_parallel:
                     handle_x.wait()
                 grad_weight1 = F.linear(grad_gelu.t(),
                                         total_x.reshape(batch_dim, total_x.shape[-1]).t())
@@ -340,7 +352,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         if process_group is not None and ctx.needs_input_grad[0]:
             handle_grad_input.wait()
         return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
-                None, None, None, None, None)
+                None, None, None, None, None, None)

 def fused_dense_gelu_dense_func(
@@ -348,15 +360,16 @@ def fused_dense_gelu_dense_func(
     bias2: Optional[Tensor] = None,
     save_pre_act: bool = True, return_residual: bool = False,
     checkpoint_lvl: int = 0, heuristic: int = 0,
-    process_group: Optional[ProcessGroup] = None
+    process_group: Optional[ProcessGroup] = None,
+    sequence_parallel: bool = True
 ):
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
             and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
-            x, weight1, bias1, weight2, bias2,
-            save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
+            x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
+            checkpoint_lvl, heuristic, process_group, sequence_parallel
         )
     else:
         assert process_group is None
@@ -418,7 +431,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
     def __init__(self, in_features, hidden_features, out_features=None,
                  process_group: ProcessGroup = None, bias1=True, bias2=True,
-                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
+                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
         """
         process_group is required. We're doing Tensor Parallel with sequence parallelism:
         we do an all_gather of x before doing the matmul, gelu, then matmul.
@@ -441,6 +454,7 @@ class ParallelFusedDenseGeluDense(nn.Module):
         if out_features is None:
             out_features = in_features
         self.process_group = process_group
+        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
         self.heuristic = heuristic
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
@@ -452,6 +466,9 @@ class ParallelFusedDenseGeluDense(nn.Module):
         out = fused_dense_gelu_dense_func(
             x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
             save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
-            heuristic=self.heuristic, process_group=self.process_group
+            heuristic=self.heuristic,
+            process_group=self.process_group,
+            sequence_parallel=self.sequence_parallel
         )
-        return reduce_scatter(out, self.process_group)
+        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+        return reduce_fn(out, self.process_group)
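As a usage note for the ColumnParallelLinear and RowParallelLinear layers changed above, the following is a minimal sketch of running them without sequence parallelism. It assumes a NCCL process group launched via torchrun, one CUDA device per rank, and that fused_dense_lib is built; the import path and the shapes are assumptions, not taken from this commit:

import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear  # path assumed

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
group = dist.group.WORLD

dim, hidden = 1024, 4096
# sequence_parallel=False: the input is replicated on every rank, so no all-gather
# happens before fc1, and fc2's partial output is all-reduced instead of reduce-scattered.
fc1 = ColumnParallelLinear(dim, hidden, group, sequence_parallel=False,
                           device='cuda', dtype=torch.float16)
fc2 = RowParallelLinear(hidden, dim, group, sequence_parallel=False,
                        device='cuda', dtype=torch.float16)

x = torch.randn(8 * 512, dim, device='cuda', dtype=torch.float16)  # (batch*seqlen, dim), replicated
out = fc2(F.gelu(fc1(x)))  # full (batch*seqlen, dim) output on every rank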
@@ -14,7 +14,7 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
     torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base

-# Raw operation, oes does support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
 def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
@@ -24,7 +24,7 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
     return output, handle

-# Raw operation, oes does support autograd, but does support async
+# Raw operation, does not support autograd, but does support async
 def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
@@ -36,6 +36,13 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
     return output, handle

+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+

 class AllGatherFunc(torch.autograd.Function):
     """Gather the input from sequence parallel region and concatenate."""
@@ -74,12 +81,30 @@ class ReduceScatterFunc(torch.autograd.Function):
 reduce_scatter = ReduceScatterFunc.apply

-def sync_sequence_parallel_params(model: torch.nn.Module, process_group: ProcessGroup):
-    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+class AllReduceFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _shared_params=True in the same order,
     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
-    params_seqparallel = {name: p for name, p in model.named_parameters()
-                          if getattr(p, '_sequence_parallel', False)}
-    for _, p in sorted(params_seqparallel.items()):
+    pamams_shared = {name: p for name, p in model.named_parameters()
+                     if getattr(p, '_shared_params', False)}
+    for _, p in sorted(pamams_shared.items()):
         with torch.no_grad():
             # Broadcast needs src to be global rank, not group rank
             torch.distributed.broadcast(
@@ -94,8 +119,9 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: Proc
     params_seqparallel = {name: p for name, p in model.named_parameters()
                           if getattr(p, '_sequence_parallel', False)}
     grads = [p.grad for _, p in sorted(params_seqparallel.items())]
-    with torch.no_grad():
-        coalesced = torch._utils._flatten_dense_tensors(grads)
-        torch.distributed.all_reduce(coalesced, group=process_group)
-        for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
-            buf.copy_(synced)
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
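The two helpers above are meant to bracket a tensor-parallel training step. A sketch under that assumption (the model, batch, labels, optimizer, and process_group names are hypothetical placeholders, and the loss computation is illustrative only):

import torch
from flash_attn.utils.distributed import sync_shared_params, allreduce_sequence_parallel_grad

def setup_shared_params(model, process_group):
    # Right after initialization, make the replicated ("_shared_params") parameters
    # identical across the tensor-parallel group (norm weights, tied embeddings, ...).
    sync_shared_params(model, process_group)

def train_step(model, batch, labels, optimizer, process_group):
    logits = model(batch).logits  # placeholder model call
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.shape[-1]), labels.view(-1))
    loss.backward()
    # With sequence parallelism, params marked _sequence_parallel=True only saw their
    # local token shard, so their grads are summed across the group before the update.
    allreduce_sequence_parallel_grad(model, process_group)
    optimizer.step()
    optimizer.zero_grad()
    return loss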
@@ -23,10 +23,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -59,7 +61,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
                       scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
                       fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
                       rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
-                      pad_vocab_size_multiple=8 * world_size)
+                      pad_vocab_size_multiple=8 * world_size,
+                      sequence_parallel=sequence_parallel)
     model_pt = GPTLMHeadModel(config, device=device)

     def init_layer_norm(module):
@@ -75,16 +78,15 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     torch.distributed.all_gather_into_tensor(
         sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
     )
-    sequence_parallel_nparams = sum(p.numel() for p in model.parameters()
-                                    if getattr(p, '_sequence_parallel', False))
-    sequence_parallel_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
+    shared_nparams = sum(p.numel() for p in model.parameters()
+                         if getattr(p, '_shared_params', False))
+    shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
     torch.distributed.all_gather_into_tensor(
-        sequence_parallel_nparams_all, torch.tensor([sequence_parallel_nparams], device=device),
-        group=process_group
+        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
     )
-    assert torch.all(sequence_parallel_nparams_all == sequence_parallel_nparams)
-    assert total_nparams == ((sharded_nparams_all - sequence_parallel_nparams_all).sum().item()
-                             + sequence_parallel_nparams)
+    assert torch.all(shared_nparams_all == shared_nparams)
+    assert total_nparams == ((sharded_nparams_all - shared_nparams_all).sum().item()
+                             + shared_nparams)
     # vocab_size has been rounded up here
     partition_vocab_size = config.vocab_size // world_size
@@ -96,6 +98,8 @@ def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
+        if not sequence_parallel:
+            out = rearrange(out, 'b s d -> (b s) d')
         out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
...
@@ -23,11 +23,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, world_size, dtype):
+def test_block_parallel(dim, sequence_parallel, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -41,7 +43,7 @@ def test_block_parallel(dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype,
@@ -51,8 +53,12 @@ def test_block_parallel(dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
-    residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+        residual = tensor_parallel.scatter_to_sequence_parallel_region(residual_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
+        residual = residual_pt.detach().clone().requires_grad_()
     mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
                            use_flash_attn=True, device=device, dtype=dtype)
@@ -69,12 +75,12 @@ def test_block_parallel(dim, world_size, dtype):
     mixer_cls = partial(ParallelMHA, num_heads=num_heads,
                         process_group=parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
                       process_group=parallel_state.get_tensor_model_parallel_group(),
-                      device=device, dtype=dtype)
+                      sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
-                  sequence_parallel=True)
+                  sequence_parallel=sequence_parallel, mark_shared_params=True)
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
@@ -115,25 +121,34 @@ def test_block_parallel(dim, world_size, dtype):
     out_pt, out_residual_pt = [rearrange(x, 'b s d -> (b s) d') for x in [out_pt, out_residual_pt]]
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     assert torch.allclose(
-        out_residual, out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out_residual,
+        out_residual_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_residual_pt,
         rtol=rtol, atol=atol
     )

-    out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    (out_pt + 2 * out_residual_pt).backward(g)
+    (out + 2 * out_residual).backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                                      if sequence_parallel else g)
     allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
     parallel_state.destroy_model_parallel()

     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     assert torch.allclose(
-        residual.grad, residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        residual.grad,
+        residual_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else residual_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
...
@@ -19,10 +19,12 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
+def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
     vocab_size = 50264
     seqlen = 2048
     assert vocab_size % world_size == 0
@@ -46,7 +48,7 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
                               device=device, dtype=dtype)
     model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
                                    parallel_state.get_tensor_model_parallel_group(),
-                                   device=device, dtype=dtype)
+                                   sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     partition_vocab_size = vocab_size // world_size
     partition_dim = dim // world_size
     with torch.no_grad():
@@ -62,13 +64,16 @@ def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
     out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     g = torch.randn_like(out_pt)
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
...
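What the embedding test exercises: the vocabulary is sharded across ranks (hence the assert vocab_size % world_size == 0 and partition_vocab_size), each rank embeds only the tokens that fall in its shard, and the partial results are summed across the group; with sequence_parallel=True the summed output is then split along the token dimension instead of being kept in full on every rank. A self-contained sketch of the vocab-sharding idea only, written independently of the flash_attn and apex implementations (all names below are illustrative, and the process group must already be initialized):

import torch
import torch.distributed as dist

def vocab_parallel_embed(input_ids, weight_shard, rank, world_size, group=None):
    # weight_shard holds rows [rank * V/p, (rank + 1) * V/p) of the full (V, d) table.
    shard_size = weight_shard.shape[0]
    vocab_start = rank * shard_size
    # Tokens outside this rank's shard contribute zeros; the all-reduce sums the shards.
    mask = (input_ids >= vocab_start) & (input_ids < vocab_start + shard_size)
    local_ids = (input_ids - vocab_start).clamp(0, shard_size - 1)
    out = weight_shard[local_ids]
    out[~mask] = 0.0
    dist.all_reduce(out, group=group)
    return out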
@@ -21,11 +21,13 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('head_dim', [64, 128])
 # @pytest.mark.parametrize('head_dim', [64])
 @pytest.mark.parametrize('embed_dim', [1024, 4096])
 # @pytest.mark.parametrize('embed_dim', [1024])
-def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
+def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
     assert embed_dim % head_dim == 0
     num_heads = embed_dim // head_dim
     assert num_heads % world_size == 0
@@ -38,7 +40,7 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 1024
     assert (batch_size * seqlen) % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
@@ -47,14 +49,17 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
                    use_flash_attn=True, device=device, dtype=dtype)
     partition_dim = embed_dim // world_size
     model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                         rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
-                        device=device, dtype=dtype)
+                        sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     with torch.no_grad():
         model.Wqkv.weight.copy_(
@@ -75,17 +80,22 @@ def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
     out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
-        rtol=rtol, atol=atol
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
+        rtol=rtol, atol=atol / 100  # magnitude of x.grad is quite small
     )
     # The error for d_weight and d_bias is quite a bit higher
     assert torch.allclose(
...
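The only change to how the test inputs are produced is the new if sequence_parallel: branch above: the scatter helper hands each rank its chunk of x_pt along the token dimension, while the else branch keeps the full, replicated input. A forward-only stand-in for that split (illustrative, not the helper the tests call; the real op also gathers the gradient back along the token dimension in backward):

import torch

def local_input(x_pt: torch.Tensor, rank: int, world_size: int,
                sequence_parallel: bool) -> torch.Tensor:
    if sequence_parallel:
        # Each rank gets a contiguous (batch * seqlen) / world_size slice of the tokens.
        x_local = x_pt.chunk(world_size, dim=0)[rank]
    else:
        # Tensor parallelism only: the input is replicated on every rank.
        x_local = x_pt
    # Match the tests: return a leaf tensor that accumulates its own gradient.
    return x_local.detach().clone().requires_grad_()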
@@ -19,14 +19,15 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [8])
+# @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias', [True, False])
-# @pytest.mark.parametrize('has_bias', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-# @pytest.mark.parametrize('in_features', [4096])
-def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
+# @pytest.mark.parametrize('has_bias', [False])
+@pytest.mark.parametrize('out_features', [1024])
+@pytest.mark.parametrize('in_features', [4096])
+def test_fused_linear_bias(in_features, out_features, has_bias, sequence_parallel,
+                           world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
@@ -37,18 +38,21 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
                        requires_grad=True)
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
     partition_out_features = out_features // world_size
     model = ColumnParallelLinear(in_features, out_features,
                                  parallel_state.get_tensor_model_parallel_group(), bias=has_bias,
-                                 device=device, dtype=dtype)
+                                 sequence_parallel=sequence_parallel, device=device, dtype=dtype)
     with torch.no_grad():
         model.weight.copy_(
             model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
@@ -73,7 +77,9 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
         rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
@@ -94,13 +100,14 @@ def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtyp
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 # @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('sequence_parallel', [True, False])
+# @pytest.mark.parametrize('sequence_parallel', [False])
 @pytest.mark.parametrize('has_bias2', [True, False])
 # @pytest.mark.parametrize('has_bias2', [True])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-# @pytest.mark.parametrize('out_features', [1024])
-@pytest.mark.parametrize('in_features', [1024, 4096])
-# @pytest.mark.parametrize('in_features', [1024])
-def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
+@pytest.mark.parametrize('out_features', [4096])
+@pytest.mark.parametrize('in_features', [1024])
+def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
+                                world_size, dtype):
     assert out_features % world_size == 0
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
     if not torch.distributed.is_initialized():
@@ -111,7 +118,7 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     rank = parallel_state.get_tensor_model_parallel_rank()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 2
     seqlen = 512
     assert batch_size * seqlen % world_size == 0
     x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
@@ -120,7 +127,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     # as rank 0 will have an extra bias that changes the RNG.
     # If we don't divide by batch_size, the gradient gets a bit too large.
     g = torch.randn_like(x_pt) / 32
-    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    if sequence_parallel:
+        x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+    else:
+        x = x_pt.detach().clone().requires_grad_()
     model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
     model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
@@ -129,7 +139,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     partition_in_features = in_features // world_size
     model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
                                         process_group=parallel_state.get_tensor_model_parallel_group(),
-                                        bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
+                                        bias2=has_bias2 and rank == 0,
+                                        sequence_parallel=sequence_parallel,
+                                        device=device, dtype=dtype)
     with torch.no_grad():
         model.fc1.weight.copy_(
@@ -148,16 +160,21 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size
     out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
     partition_batch_dim = batch_size * seqlen // world_size
     assert torch.allclose(
-        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        out,
+        out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else out_pt,
         rtol=rtol, atol=atol
     )
     out_pt.backward(g)
-    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+                 if sequence_parallel else g)
     parallel_state.destroy_model_parallel()
     assert torch.allclose(
-        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        x.grad,
+        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
+        if sequence_parallel else x_pt.grad,
        rtol=rtol, atol=atol
     )
     # The error for d_weight and d_bias is quite a bit higher
...
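Taken together, the gradient checks encode the communication pattern of the two modes: with sequence_parallel=False the input is replicated, each rank computes a partial dL/dx from its column shard of the weight, and a backward all-reduce leaves every rank with the full input gradient (hence the comparison against the full x_pt.grad); with sequence_parallel=True that all-reduce becomes a reduce-scatter and each rank keeps only its token slice. A toy sketch of this reduction for a column-parallel linear, under the assumption that the process group is already initialized (this is not the fused kernel that ColumnParallelLinear uses):

import torch
import torch.distributed as dist

def column_parallel_fwd_bwd(x, weight_shard, grad_out_shard, group=None,
                            sequence_parallel=False):
    # x: replicated (tokens, in_features); weight_shard: (out_features / p, in_features).
    # (With sequence parallelism the forward would first all-gather x along the
    #  token dimension; that step is omitted here to keep the sketch short.)
    out_shard = x @ weight_shard.t()                 # this rank's output columns
    grad_x_partial = grad_out_shard @ weight_shard   # partial dL/dx from this shard
    if sequence_parallel:
        # Reduce-scatter: sum the partials and keep only this rank's token slice.
        chunks = list(grad_x_partial.chunk(dist.get_world_size(group), dim=0))
        grad_x = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_x, chunks, group=group)
    else:
        # All-reduce: every rank ends up with the full dL/dx.
        dist.all_reduce(grad_x_partial, group=group)
        grad_x = grad_x_partial
    return out_shard, grad_x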