Commit 343492ec authored by Tri Dao's avatar Tri Dao
Browse files

Make nccl operations async in CrossEntropyLossParallel

parent 3dda4f76
...@@ -36,40 +36,50 @@ class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): ...@@ -36,40 +36,50 @@ class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function):
assert labels.shape == (batch,) assert labels.shape == (batch,)
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
partition_vocab_size, get_tensor_model_parallel_rank(),
get_tensor_model_parallel_world_size()
)
# Create a mask of valid vocab ids (1 means it needs to be masked). if world_size == 1:
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) losses, lse = xentropy_cuda_lib.forward(logits_parallel, labels, smoothing)
ignored_mask = labels == ignored_index losses.masked_fill_(labels==ignored_index, 0)
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) labels_local = labels
masked_labels = labels_local.clone() else:
masked_labels[labels_mask] = ignored_index vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
partition_vocab_size, get_tensor_model_parallel_rank(),
get_tensor_model_parallel_world_size()
)
# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
masked_labels = labels_local.clone()
masked_labels[labels_mask] = ignored_index
losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing)
assert lse_local.shape == (batch,) assert lse_local.shape == (batch,)
assert losses.shape == (batch,) assert losses.shape == (batch,)
losses.masked_fill_(masked_labels==ignored_index, 0) losses.masked_fill_(masked_labels==ignored_index, 0)
if world_size > 1:
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
device=lse_local.device) device=lse_local.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), handle_lse = torch.distributed.all_gather_into_tensor(
group=get_tensor_model_parallel_group()) lse_allgather, lse_local.contiguous(),
group=get_tensor_model_parallel_group(), async_op=True
)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(), async_op=True
)
handle_lse.wait()
lse = torch.logsumexp(lse_allgather, dim=0) lse = torch.logsumexp(lse_allgather, dim=0)
torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, # The losses are going to be lse_local - predicted_logit, we just have to subtract
group=get_tensor_model_parallel_group()) # the lse_local and add the lse (global).
# The losses are currently lse_local - predicted_logit, we just have to subtract the rank_per_sample = torch.div(labels, partition_vocab_size, rounding_mode='floor')
# lse_local and add the lse (global).
rank_per_sample = labels // partition_vocab_size
lse_local = lse_allgather[rank_per_sample, lse_local = lse_allgather[rank_per_sample,
torch.arange(batch, device=lse_allgather.device)] torch.arange(batch, device=lse_allgather.device)]
handle_losses.wait()
losses += lse - lse_local losses += lse - lse_local
losses.masked_fill_(ignored_mask, 0) losses.masked_fill_(ignored_mask, 0)
else:
lse = lse_local
ctx.save_for_backward(logits_parallel, lse, labels_local) ctx.save_for_backward(logits_parallel, lse, labels_local)
ctx.smoothing = smoothing ctx.smoothing = smoothing
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from einops import rearrange from einops import rearrange
from src.losses.cross_entropy_apex import CrossEntropyLossApex from flass_attn.losses.cross_entropy_apex import CrossEntropyLossApex
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
......
...@@ -10,7 +10,7 @@ import pytest ...@@ -10,7 +10,7 @@ import pytest
from apex.transformer import parallel_state from apex.transformer import parallel_state
from apex.transformer import tensor_parallel from apex.transformer import tensor_parallel
from src.losses.cross_entropy_parallel import CrossEntropyLossParallel from flash_attn.losses.cross_entropy_parallel import CrossEntropyLossParallel
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment