Commit 70174ae3 authored by Raul Puri's avatar Raul Puri
Browse files

Merge branch 'memory_optimization' into 'master'

memory optimization in mpu cross entropy

See merge request ADLR/megatron-lm!32
parents ca8dd4ac 57064fd6
......@@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Copy so the input remains unchanged.
logits = vocab_parallel_logits.clone()
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(logits, dim=-1)[0]
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
# Subtract the maximum value.
logits.sub_(logits_max.unsqueeze(dim=-1))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = logits.exp()
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
......@@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = logits.view(-1, partition_vocab_size)
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
......@@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment