Unverified Commit 424629fe authored by Bin Jia, committed by GitHub

[shardformer/sequence parallel] Cherry pick commit to new branch (#4450)

* [shardformer/sequence parallel] Support sequence parallel for gpt2 (#4384)

* [sequence parallel] add sequence parallel linear col/row support (#4336)

* add sequence parallel linear col/row support

* add annotation

* add annotation

* add support for gpt2 fused qkv linear layer

* support sequence parallel in GPT2

* add docstring and note

* add requirements

* remove unused flash-attn

* modify flash attn test

* modify flash attn setting

* modify flash attn code

* add assert before divide, rename forward function

* [shardformer/test] fix gpt2 test with seq-parallel

* [shardformer/sequence parallel] Overlap input gather and grad computation during col backward (#4401)

* overlap gather input / grad computing during col backward

* modify test for overlap

* simplify code

* fix code and modify cuda stream synchronize

* [shardformer/sequence parallel] polish code
parent d20dceb9
...@@ -152,6 +152,7 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
...@@ -178,6 +179,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
...@@ -195,7 +197,8 @@ class HybridParallelPlugin(PipelinePluginBase):
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
......
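For orientation, here is a minimal usage sketch (not part of the diff) showing how the new plugin flag might be passed; the launcher, model, and optimizer setup are assumed to exist already, and the exact keyword values are illustrative.

```python
# Hedged sketch: enabling the new flag through the booster plugin.
# Assumes colossalai.launch(...) has been called and `model` / `optimizer` exist.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                          # sequence parallelism reuses the tensor-parallel group
    pp_size=1,
    enable_sequence_parallelism=True,   # flag introduced by this commit
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```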
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
...@@ -141,6 +143,215 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whether to overlap the all-gather op and gradient computation in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
# create new stream for calculate the gradient
calculate_stream = torch.cuda.Stream()
# do all gather in default stream
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient in calculate_stream
with torch.cuda.stream(calculate_stream):
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
torch.cuda.current_stream().wait_stream(calculate_stream)
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
torch.cuda.current_stream().wait_stream(calculate_stream)
return output, grad_weight, grad_bias, None, None, None, None
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group
# do reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chunk for the rank.
...@@ -200,6 +411,26 @@ class _ReduceBackward(torch.autograd.Function):
return _reduce(grad_output, ctx.process_group), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
dim: dimension to gather along.
process_group: the process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
...@@ -235,6 +466,7 @@ def _gather(input_, dim=-1, process_group=None):
return input_
# all gather
input_ = input_.contiguous()
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
...@@ -246,24 +478,27 @@ def _gather(input_, dim=-1, process_group=None):
return output
def _reduce_scatter(input_, dim=1, process_group=None):
""" Do reduce-scatter operation.
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
dim (int): The dimension to perform reduce-scatter.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)
return output
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
...@@ -274,6 +509,21 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim)
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
......
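The new autograd functions are exposed through the thin wrappers above. Below is a rough sketch (not from the diff) of how a column-parallel then row-parallel pair might call them; it assumes `torch.distributed` is initialized, `pg` is the tensor-parallel group, `x` is this rank's sequence shard of shape [batch, seq_len // world_size, hidden], and the weight shapes follow the usual 1D column/row sharding.

```python
import torch
from colossalai.shardformer.layer._operation import (
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
)

def col_row_pair(x, w_col, b_col, w_row, pg):
    # forward: all-gather x along the sequence dim (dim=1), then run the column-parallel linear;
    # backward: the input gradient is reduce-scattered back to the local sequence shard
    y = linear_gather_forward_reducescatter_backward(x, w_col, b_col, pg, True, 1, False)
    # row-parallel partial product; forward reduce-scatters along dim=1,
    # backward all-gathers the gradient
    partial = torch.matmul(y, w_row)
    return linear_reducescatter_forward_gather_backward(partial, pg, 1)
```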
...@@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
...@@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
...@@ -69,6 +73,8 @@ class Linear1D_Col(ParallelModule):
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
...@@ -80,6 +86,8 @@ class Linear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
...@@ -180,7 +188,11 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
...@@ -203,6 +215,8 @@ class Linear1D_Row(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
...@@ -221,6 +235,7 @@ class Linear1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
...@@ -238,6 +253,7 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
...@@ -373,7 +389,10 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
......
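For reference, a small sketch mirroring the tests later in this commit: building the column/row pair from native `nn.Linear` modules with the new flags, assuming two ranks and an already-initialized default process group.

```python
# Hedged sketch based on the tests in this commit (2-rank setup assumed).
import torch
import torch.nn as nn
import torch.distributed as dist
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None,
                                             gather_output=False, seq_parallel=True, overlap=True)
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None,
                                             parallel_input=True, seq_parallel=True)

# each rank feeds its own sequence shard: [batch, seq_len // world_size, hidden]
x = torch.rand(2, 4, 32).cuda()
x_shard = torch.chunk(x, dist.get_world_size(), dim=1)[dist.get_rank()]
out = linear_row(linear_col(x_shard))    # output stays sharded along the sequence dimension
```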
...@@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
...@@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
...@@ -173,6 +176,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
...@@ -185,6 +189,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
...@@ -296,15 +301,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1)
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
self.async_communication)
if self.gather_output:
# All-gather across the partitions.
...@@ -329,6 +338,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
...@@ -346,6 +356,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
...@@ -363,6 +374,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
...@@ -499,7 +511,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
......
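The fused-QKV Conv1D variants follow the same pattern; a sketch mirroring the test file later in this commit (2 ranks assumed, module and import paths taken from this diff and the tests, so treat them as assumptions).

```python
# Hedged sketch based on the qkv-fused linear test below.
import torch
import torch.distributed as dist
from transformers.pytorch_utils import Conv1D
from colossalai.shardformer.layer.qkv_fused_linear import GPT2FusedLinearConv1D_Col

conv = Conv1D(192, 48).cuda()
conv_col = GPT2FusedLinearConv1D_Col.from_native_module(conv, process_group=None,
                                                        gather_output=True, seq_parallel=True, n_fused=3)

x = torch.rand(1, 4, 48).cuda()
x_shard = torch.chunk(x, dist.get_world_size(), dim=1)[dist.get_rank()]
out = conv_col(x_shard)    # fused QKV output, gathered along the last dim since gather_output=True
```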
# this code is modified from transformers.models.gpt2.modeling_gpt2
# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
# TODO: put all contents in `gpt2.py` and make it compatible with pipeline
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism is used, gather the output along the sequence dimension in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
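The heart of the forward above is the pair of collectives wrapped around the transformer blocks; below is a stripped-down sketch of that pattern, assuming `pg` is the tensor-parallel process group and `blocks` stands in for `model.h`.

```python
# Minimal sketch of the sequence-parallel wrapping used by the forward above.
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)

def run_blocks_sequence_parallel(hidden_states, blocks, pg):
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len / tp_size, hidden_size]
    hidden_states = split_forward_gather_backward(hidden_states, dim=1, process_group=pg)
    for block in blocks:
        hidden_states = block(hidden_states)[0]    # GPT2Block returns a tuple
    # gather the full sequence back before the final layer norm
    return gather_forward_split_backward(hidden_states, dim=1, process_group=pg)
```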
...@@ -11,17 +11,12 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
class ParallelModule():
def __init__(self):
pass
@dataclass
class SubModuleReplacementDescription:
r"""
...@@ -231,3 +226,22 @@ class Policy(ABC):
end_idx = num_layers_per_stage_accumulated[stage + 1]
return [start_idx, end_idx]
def append_seq_parallel_to_policy(
self,
suffix_list: List[str],
module_policy_description: ModulePolicyDescription,
):
r"""
Append the sequence parallel policy to the policy for the given key.
Args:
suffix_list (List[str]): the suffix list of the module to be parallelized
module_policy_description (ModulePolicyDescription): the module policy description to be updated
"""
for sub_description in module_policy_description.sub_module_replacement:
if (sub_description.suffix in suffix_list):
if sub_description.kwargs is None:
sub_description.kwargs = {}
sub_description.kwargs["seq_parallel"] = True
...@@ -7,6 +7,7 @@ import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
...@@ -49,6 +50,9 @@ class GPT2Policy(Policy):
target_module=col_nn.DropoutForParallelInput,
),
])
if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
...@@ -120,6 +124,11 @@ class GPT2Policy(Policy):
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
'forward': get_gpt2_flash_attention_forward(),
})
if self.shard_config.enable_sequence_parallelism:
suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])
return policy
def postprocess(self):
......
...@@ -28,6 +28,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
......
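When driving ShardFormer directly (as the test helper later in this commit does), the flag can be set on the config; a short sketch, assuming `model` is a GPT-2 module and the distributed environment is already initialized.

```python
# Hedged sketch: enabling sequence parallelism on a ShardConfig (mirrors the test utils below).
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(enable_tensor_parallelism=True,
                           enable_sequence_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model)
```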
...@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
...@@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
n_fused=3)
assert linear.weight.shape == torch.Size([48, 192])
...@@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
# check backward correctness
out.sum().backward()
...@@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert_close(target_grad, linear_conv_col.weight.grad)
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
...@@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
...@@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool):
assert_close(target_grad, linear_row.weight.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel)
check_linear_conv_1d_row(lazy_init, seq_parallel)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_gpt2_qkv_fused_linear_1d()
@rerun_if_address_is_in_use()
......
...@@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
overlap=overlap)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
...@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
...@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_grad, x_for_shard.grad)
def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
...@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
...@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
...@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
...@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
with ctx:
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
seq_parallel=seq_parallel)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
...@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
...@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_grad, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
def check_dist_linear(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
spawn(check_dist_linear, nprocs=2)
if __name__ == '__main__':
......
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
...@@ -25,6 +26,7 @@ def build_model(model_fn,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
...@@ -38,7 +40,8 @@ def build_model(model_fn,
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
...@@ -135,6 +138,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
return loss
data = data_gen_fn()
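# if sequence parallelism is enabled, repeat the inputs along the sequence dimension
# so that the sequence length becomes a multiple of tp_size and can be split evenly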
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data['input_ids'].shape[1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data['input_ids'].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat(1, times)
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
......
...@@ -106,6 +106,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': False,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
......