Commit 3893fa1a authored by Frank Lee

[shardformer] refactored embedding and dropout to parallel module (#4013)

* [shardformer] refactored embedding and dropout to parallel module

* polish code
parent dfca9678
@@ -3,6 +3,7 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
+from torch.distributed import ProcessGroup
class DistCrossEntropy(Function):
@@ -14,7 +15,7 @@ class DistCrossEntropy(Function):
"""
@staticmethod
-def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
+def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
r"""
Calculate the cross entropy loss before gathering; the original loss function is as follows:
loss = -log(exp(x[class]) / sum(exp(x[i])))
@@ -34,15 +35,15 @@ class DistCrossEntropy(Function):
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
-dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
+dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group)
# subtract the max so that sum(exp(...)) does not overflow and the log does not become nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# build the mask for targets that fall outside the local device's vocabulary partition
partition_vocab_size = vocab_logits.size()[-1]
-rank = dist.get_rank()
-world_size = dist.get_world_size()
+rank = dist.get_rank(group=process_group)
+world_size = dist.get_world_size(group=process_group)
global_vocab_size = partition_vocab_size * world_size
# targets in [down, up) => False; targets owned by another rank or equal to -100 (ignore_index) => True
@@ -67,11 +68,11 @@ class DistCrossEntropy(Function):
pred_logits[mask] = 0.0
# all-reduce to sum up the partial x[i, y] contributions from all ranks
-dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
+dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
-dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
+dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
@@ -101,5 +102,8 @@ class DistCrossEntropy(Function):
return grad_logits, None, None
-def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
-return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
+def cross_entropy_1d(vocab_logits: torch.Tensor,
+labels: torch.Tensor,
+ignore_index: int = -100,
+process_group: ProcessGroup = None) -> torch.Tensor:
+return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
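For reference, the math that cross_entropy_1d distributes can be reproduced on a single device. The sketch below is only an illustration (the function and variable names here are made up, and ignore_index handling is omitted) of the masking and log-sum-exp steps in DistCrossEntropy.forward, using local slices in place of per-rank shards:

import torch
import torch.nn.functional as F

def sharded_cross_entropy_sketch(vocab_logits: torch.Tensor, target: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    # Emulate the vocabulary shards that each rank would hold; in the real
    # layer every rank only owns one slice and the sums below are all-reduces.
    shards = torch.chunk(vocab_logits, world_size, dim=-1)
    partition_vocab_size = shards[0].size(-1)

    # global max over the full vocabulary (all_reduce MAX) for numerical stability
    logits_max = vocab_logits.max(dim=-1, keepdim=True)[0]

    sum_exp_logits = torch.zeros_like(target, dtype=vocab_logits.dtype)
    pred_logits = torch.zeros_like(target, dtype=vocab_logits.dtype)
    for rank, shard in enumerate(shards):
        shard = shard - logits_max
        sum_exp_logits = sum_exp_logits + shard.exp().sum(dim=-1)           # all_reduce SUM
        down, up = rank * partition_vocab_size, (rank + 1) * partition_vocab_size
        mask = (target >= down) & (target < up)                             # targets owned by this shard
        local_target = target.clamp(min=down, max=up - 1) - down
        picked = shard.gather(-1, local_target.unsqueeze(-1)).squeeze(-1)
        pred_logits = pred_logits + torch.where(mask, picked, torch.zeros_like(picked))  # all_reduce SUM

    # loss = log(sum(exp(x[i]))) - x[class]
    return (sum_exp_logits.log() - pred_logits).mean()

logits, labels = torch.randn(8, 100), torch.randint(0, 100, (8,))
print(sharded_cross_entropy_sketch(logits, labels))
print(F.cross_entropy(logits, labels))  # should agree up to floating point error

The three reductions in the sketch (MAX for logits_max, SUM for sum_exp_logits, SUM for pred_logits) correspond to the three dist.all_reduce calls above, which this commit scopes to the given process_group.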
from typing import List, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from .layers import ParallelModule
from .utils import create_randomizer_with_offset
-class Dropout1D(nn.Dropout):
+class Dropout1D(ParallelModule, nn.Dropout):
"""
The Dropout Layer applies a dropout mask to the input tensor. The dropout mask is generated with
different randomness on each rank of the given process group. This avoids generating and applying the
same dropout mask at the same positions on different ranks, which would hurt convergence.
Args:
p (float): probability of an element to be zeroed. Defaults to 0.5.
inplace (bool): If set to True, will do this operation in-place. Defaults to False.
process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None.
"""
-def __init__(self, p=0.5, inplace=False, process_group=None):
-super().__init__(p, inplace)
+def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None):
+# init with nn.Dropout
+super(nn.Dropout, self).__init__(p=p, inplace=inplace)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=process_group)
@staticmethod
def from_native_module(module: nn.Dropout,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D":
"""
Create a Dropout1D layer from a native dropout layer.
"""
p = module.p
inplace = module.inplace
return Dropout1D(p=p, inplace=inplace, process_group=process_group)
def forward(self, input):
with self.randomizer.fork_rng():
input = super().forward(input)
......
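The per-rank randomness comes from create_randomizer_with_offset, which offsets the RNG seed so each rank draws a different dropout mask while staying reproducible. A minimal sketch of that idea, assuming the offset is simply seed + rank (the real helper lives in shardformer's utils and may derive the offset differently), could look like:

import torch
import torch.nn.functional as F

def dropout_with_rank_offset(x: torch.Tensor, p: float, seed: int, rank: int) -> torch.Tensor:
    # Fork the RNG state so seeding here does not disturb the caller,
    # then offset the seed by the rank so every rank gets a different mask.
    devices = [x.device] if x.is_cuda else []
    with torch.random.fork_rng(devices=devices):
        torch.manual_seed(seed + rank)   # hypothetical offset scheme: seed + rank
        return F.dropout(x, p=p, training=True)

x = torch.ones(4, 8)
a = dropout_with_rank_offset(x, p=0.5, seed=1024, rank=0)
b = dropout_with_rank_offset(x, p=0.5, seed=1024, rank=1)
print(torch.equal(a, b))  # typically False: the masks decorrelate across "ranks"

This is exactly what the test below checks: every rank sets the same CUDA seed, the plain nn.Dropout outputs match across ranks, but the Dropout1D outputs are expected to differ.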
This diff is collapsed.
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
def check_dropout():
dropout = nn.Dropout().cuda()
dropout_1d = Dropout1D.from_native_module(dropout, process_group=None)
# check computation correctness
x = torch.rand(4, 128).cuda()
# we set seed so that dropout will generate the same mask
torch.cuda.manual_seed(1024)
out = dropout(x)
# we set seed to simulate the same scenario
# but expect the dropout mask to be different
# due to the internal randomness control
torch.cuda.manual_seed(1024)
out_1d = dropout_1d(x)
# ensure out is the same across all ranks
world_size = dist.get_world_size()
out_all = [torch.empty_like(out) for _ in range(world_size)]
dist.all_gather(out_all, out)
for i in range(world_size):
assert_equal(out_all[i], out_all[0])
# ensure out_1d is different across ranks
out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)]
dist.all_gather(out_1d_all, out_1d)
for i in range(1, world_size):
assert_not_equal(out_1d_all[i], out_1d_all[0])
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_dropout()
@rerun_if_address_is_in_use()
def test_dropout():
spawn(run_dist, nprocs=2)
if __name__ == '__main__':
test_dropout()
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.layers import Embedding1D
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_embedding_1d():
embedding = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
assert embedding_1d.weight.shape == torch.Size([32, 64])
# check computation correctness
x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
out = embedding(x)
gather_out = embedding_1d(x)
assert_close(out, gather_out)
# check backward correctness
out.sum().backward()
gather_out.sum().backward()
rank = dist.get_rank()
target_grad = torch.chunk(embedding.weight.grad, 2, dim=1)[rank]
assert_close(target_grad, embedding_1d.weight.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_embedding_1d()
@rerun_if_address_is_in_use()
def test_embedding_1d():
spawn(run_dist, nprocs=2)
if __name__ == '__main__':
test_embedding_1d()
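The shapes in this test come from Embedding1D sharding the weight along the embedding dimension: with 2 ranks the (32, 128) table becomes (32, 64) per rank, each rank embeds into its own slice, and the dense output is recovered by gathering along the last dimension, which is also why each rank's weight gradient equals a dim=1 chunk of the dense gradient. The Embedding1D implementation itself sits in the collapsed diff above, so the following is only an illustrative sketch of that column-parallel scheme, not the actual code:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

class ColumnShardedEmbeddingSketch(nn.Module):
    # Illustrative column-parallel embedding: each rank keeps a slice of the
    # embedding dimension and the dense output is rebuilt with an all-gather.
    def __init__(self, full: nn.Embedding, process_group=None):
        super().__init__()
        self.rank = dist.get_rank(process_group)
        self.world_size = dist.get_world_size(process_group)
        shard = torch.chunk(full.weight.detach(), self.world_size, dim=1)[self.rank]
        self.weight = nn.Parameter(shard.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local_out = F.embedding(x, self.weight)                     # (..., dim // world_size)
        gathered = [torch.empty_like(local_out) for _ in range(self.world_size)]
        dist.all_gather(gathered, local_out)
        gathered[self.rank] = local_out   # keep the local slice differentiable
        return torch.cat(gathered, dim=-1)                          # (..., dim)

In the actual Embedding1D the gather has to be autograd-aware for backward to work end to end; the local-slice patch above is just the simplest way to keep this sketch's own backward correct, and it yields the same check as the test: out and gather_out match in the forward pass, and backward leaves each rank with the gradient of its own (32, 64) shard.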
@@ -5,7 +5,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_linear_1d_col():
......