"...git@developer.sourcefind.cn:modelzoo/pasd_pytorch.git" did not exist on "d7c40c9a1be0c64234b259a745bd81ecd9f1eb59"
Commit 015af592 authored by Frank Lee

[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor

* polish code
parent d3bc5308
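The central change is mechanical but pervasive: layers and autograd functions that previously looked up their communication group through a ColossalAI parallel mode (via `gpc`) now receive a `torch.distributed.ProcessGroup` directly. A minimal sketch of the resulting API, based on the test added at the bottom of this diff (a two-rank distributed context is assumed to be already launched):

```python
import torch
import torch.nn as nn

from colossalai.shardformer.layer.layers import Linear1D_Col

# build a plain PyTorch layer, then wrap it into its 1D column-parallel counterpart;
# process_group=None falls back to the default (world) process group
linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

# with 2 ranks, each rank holds half the output features: weight is [64, 32]
x = torch.rand(4, 32).cuda()
out = linear_col(x)    # gather_output=True returns the full [4, 128] result
```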
@@ -10,6 +10,7 @@ from colossalai.core import global_context as gpc
class ParallelLayer(nn.Module):
global_state_dict: bool = True
def __init__(self):
......
@@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
-def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
+def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
-ctx.parallel_mode = parallel_mode
+ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
@@ -74,12 +74,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
-grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
-total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
+if len(grad_output.shape) > 2:
+    grad_output = grad_output.view(-1, grad_output.shape[-1])
+    total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
-handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
+handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
@@ -93,5 +94,123 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
-def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
-    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to the rank.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.dim, ctx.process_group), None, None
class _ReduceInput(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_ (`torch.Tensor`): input matrix.
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, process_group):
return _reduce(input_, process_group)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
dist.all_reduce(input_, group=process_group)
return input_
def _split(input_, dim=-1, process_group=None):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# Split along the given dimension (last by default).
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = dist.get_rank(process_group)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, dim=-1, process_group=None):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# all gather
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)
def reduce_input(input_, process_group):
return _ReduceInput.apply(input_, process_group)
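Together these wrappers form the standard 1D tensor-parallel communication pattern. The following is only an illustrative sketch, not code from this commit: `row_parallel_linear` is a hypothetical helper showing how a row-parallel matmul could compose `split_forward_gather_backward` with `reduce_input`.

```python
import torch

def row_parallel_linear(x, local_weight, bias, process_group):
    """Hypothetical row-parallel linear for illustration only.
    local_weight is the [out, in/world_size] slice held by this rank."""
    # forward: keep this rank's slice of the input features;
    # backward: gather the input gradient back to full size
    x_local = split_forward_gather_backward(x, dim=-1, process_group=process_group)
    # each rank computes a partial [*, out] result from its slice
    partial = torch.matmul(x_local, local_weight.t())
    # forward: sum the partial results across ranks; backward is an identity
    out = reduce_input(partial, process_group)
    if bias is not None:
        out = out + bias
    return out
```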
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
import torch.nn as nn
class SeedManager:
"""
This class is a random state manager that switches the random state between different random seeds.
"""
def __init__(self):
original_state = torch.cuda.get_rng_state()
# TODO: unify this seed manager with the colossalai.context.random
seed = os.getpid()
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)
def set_mode(self, rng_state):
torch.cuda.set_rng_state(rng_state)
def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state
@contextmanager
def dropout_mode(self):
"""
This is a context manager to change the dropout state and recover the original state.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
finally:
self.dropout_state = self.get_current_mode()
self.set_mode(current_mode)
_seed_manager = SeedManager()
from .utils import create_randomizer_with_offset
class Dropout1D(nn.Dropout):
-def __init__(self, p=0.5, inplace=False):
+def __init__(self, p=0.5, inplace=False, process_group=None):
super().__init__(p, inplace)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
def forward(self, input):
-with _seed_manager.dropout_mode():
+with self.randomizer.fork_rng():
input = super().forward(input)
return input
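Because the randomizer seed is offset by both the randomizer index and the process rank, each rank in the group draws an independent dropout mask, which is what a tensor-parallel region needs. A short usage sketch (assumes `torch.distributed` is already initialized):

```python
import torch

# process_group=None offsets the seed by the rank in the default group
dropout = Dropout1D(p=0.5, process_group=None).cuda()
x = torch.rand(4, 32).cuda()
y = dropout(x)    # the mask differs across ranks; the RNG state outside is untouched
```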
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Randomizer:
"""
Randomizer enables the program to be executed under a different seed within the context.
Example:
```python
randomizer = Randomizer(seed=1024)
with randomizer.fork_rng():
# do something here with seed 1024
do_something()
```
Args:
seed (int): The random seed to set.
"""
_INDEX = 0
def __init__(self, seed: int):
# TODO: remove colossalai.context.random
self.seed = seed
# Handle CUDA rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
cuda_original_rng_state = torch.cuda.get_rng_state()
torch.cuda.manual_seed(seed)
self.cuda_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(cuda_original_rng_state)
# do the same for the CPU RNG state
cpu_original_rng_state = torch.get_rng_state()
torch.manual_seed(seed)
self.cpu_rng_state = torch.get_rng_state()
torch.set_rng_state(cpu_original_rng_state)
def _set_cuda_rng_state(self, rng_state):
torch.cuda.set_rng_state(rng_state)
def _get_cuda_rng_state(self):
current_state = torch.cuda.get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
torch.set_rng_state(rng_state)
def _get_cpu_rng_state(self):
current_state = torch.get_rng_state()
return current_state
@contextmanager
def fork_rng(self, enable_cpu: bool = False):
"""
This is a context manager to change the dropout state and recover the original state.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_cuda_rng_state = self._get_cuda_rng_state()
self._set_cuda_rng_state(self.cuda_rng_state)
if enable_cpu:
current_cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(self.cpu_rng_state)
yield
finally:
self.cuda_rng_state = self._get_cuda_rng_state()
self._set_cuda_rng_state(current_cuda_rng_state)
if enable_cpu:
self.cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(current_cpu_rng_state)
@staticmethod
def index():
"""
Return the index of the randomizer. The index is useful when the user wants
to introduce some randomness in the program.
Note:
The index will increment by one each time this method is called.
Example:
```python
# assume we need a randomizer to init the weights of different layers;
# we can use the index of the randomizer so that each layer gets
# its own randomizer with a different seed
base_seed = torch.random.initial_seed()
seed = base_seed + Randomizer.index()
randomizer = Randomizer(seed)
with randomizer.fork_rng():
init_weights()
```
"""
idx = Randomizer._INDEX
Randomizer._INDEX += 1
return idx
def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None):
"""
Create a randomizer with an offset. The offset is the sum of the process rank and the current randomizer index.
Args:
seed (int): The base random seed to set.
process_group (ProcessGroup): the process group to get the rank from.
Returns:
Randomizer: the randomizer with offset.
"""
offset = Randomizer.index()
if dist.is_initialized():
rank = dist.get_rank(process_group)
offset += rank
seed += offset
return Randomizer(seed=seed)
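A usage sketch following the pattern in the `Randomizer.index()` docstring: each call to `create_randomizer_with_offset` consumes a fresh index, and the process rank is added on top, so every (layer, rank) pair ends up with a distinct seed. A CUDA device is assumed here.

```python
import torch

base_seed = torch.random.initial_seed()
randomizer = create_randomizer_with_offset(base_seed, process_group=None)

with randomizer.fork_rng(enable_cpu=True):
    # both the CUDA and CPU RNG states are swapped to the randomizer's
    # stored states here, and restored to the surrounding states on exit
    weight = torch.randn(128, 32, device='cuda')
```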
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from .d_tensor import DTensor
from .sharding_spec import ShardingSpec
def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
"""
Shard the first dim of the given tensor
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
return DTensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
"""
Shard the last dim of the given tensor
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
group_or_device_mesh = dist.GroupMember.WORLD
if isinstance(group_or_device_mesh, ProcessGroup):
device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
return DTensor(tensor, device_mesh, sharding_spec)
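A hedged sketch of the two helpers, assuming two ranks in the default process group:

```python
import torch

# row-wise: dim 0 is partitioned, so each of the 2 ranks holds a [64, 32]
# shard of the original [128, 32] tensor
weight = torch.randn(128, 32)
d_row = shard_rowwise(weight)

# column-wise: the last dim is partitioned, giving a [128, 16] shard per rank
d_col = shard_colwise(torch.randn(128, 32))
```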
@@ -34,7 +34,7 @@ class Layout:
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
-mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
@@ -45,14 +45,15 @@ class Layout:
sharding_spec = self.sharding_spec
# make sure every axis in the logical device mesh is used at most once
-dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
-for dim, shard_list in sharding_spec.dim_partition_dict.items():
-    for element in shard_list:
-        if element in dim_check_list:
-            dim_check_list.remove(element)
-        else:
-            raise DuplicatedShardingDimensionError(
-                f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+if self.device_mesh.logical_mesh_id is not None:
+    dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+    for dim, shard_list in sharding_spec.dim_partition_dict.items():
+        for element in shard_list:
+            if element in dim_check_list:
+                dim_check_list.remove(element)
+            else:
+                raise DuplicatedShardingDimensionError(
+                    f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -60,7 +61,7 @@ class Layout:
num_devices = 1
for element in shard_list:
-num_devices *= self.device_mesh.mesh_shape[element]
+num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
......
@@ -304,7 +304,7 @@ class LayoutConverter(metaclass=SingletonMeta):
process_groups_dict = source_layout.device_mesh.process_groups_dict
# legal sharding dims are mesh axes that are still available for use
-legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
......
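For clarity, here is the shard-shape arithmetic from `get_sharded_shape_per_device` replayed standalone; the mesh shape, tensor shape, and partition dict below are made-up values for illustration only.

```python
from functools import reduce
import operator

# hypothetical 2x2 device mesh and a [128, 32] tensor whose dim 0 is
# sharded over both mesh axes, i.e. dim_partition_dict = {0: [0, 1]}
entire_shape = [128, 32]
mesh_shape = (2, 2)
dim_partition_dict = {0: [0, 1]}

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    mesh_list = [mesh_shape[mesh_dim] for mesh_dim in shard_list]
    shard_partitions = reduce(operator.mul, mesh_list, 1)    # 2 * 2 = 4
    assert sharded_shape[dim] % shard_partitions == 0
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)    # [32, 32] -- each device holds a quarter of dim 0
```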
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_linear_1d_col():
linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
assert linear_col.weight.shape == torch.Size([64, 32])
assert linear_col.bias.shape == torch.Size([64])
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_col(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
assert_close(target_grad, linear_col.weight.grad)
def check_linear_1d_row():
linear = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
assert_close(target_grad, linear_row.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
check_linear_1d_row()
@rerun_if_address_is_in_use()
def test_linear():
spawn(run_dist, nprocs=2)
if __name__ == '__main__':
test_linear()