Commit 5d647381 authored by Vijay Korthikanti

t5 regression fixes

parent a7a12f82
@@ -114,13 +114,16 @@ def post_language_model_processing(lm_output, pooled_output,
return lm_logits.transpose(0,1).contiguous(), binary_logits
else:
# [b s] => [s b]
lm_logits = lm_logits.transpose(0,1).contiguous()
lm_labels = lm_labels.transpose(0,1).contiguous()
# lm_logits : [s, b, h] and lm_labels: [s, b]
if fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s, b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss, binary_logits
@@ -49,6 +49,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
loss = mpu.vocab_parallel_cross_entropy(output, labels)
else:
loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
return loss
@@ -152,7 +152,7 @@ class T5Model(MegatronModule):
if self.post_process and self.add_decoder:
decoder_output, encoder_output = lm_output
# Output.
# Output. [s, b, h]
lm_logits = self.lm_head(decoder_output,
self.word_embeddings_weight())
@@ -161,13 +161,15 @@ class T5Model(MegatronModule):
return lm_logits.transpose(0,1).contiguous()
else:
# [b s] => [s b]
lm_labels = lm_lables.transpose(0,1).contiguous()
lm_labels = lm_labels.transpose(0,1).contiguous()
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss
elif self.add_decoder and not self.add_encoder:
decoder_output, encoder_output = lm_output
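The hunks above all apply the same shape convention: `mpu.vocab_parallel_cross_entropy` is fed sequence-major tensors (logits in [s, b, vocab-partition], labels in [s, b]) and returns a per-token loss in [s, b], which the callers now transpose back to batch-major [b, s] for downstream loss masking. A minimal sketch of that shape flow, with a plain cross entropy standing in for the tensor-parallel one (no `mpu` here; names are illustrative):

```python
import torch
import torch.nn.functional as F

def toy_vocab_parallel_cross_entropy(logits, labels):
    # Stand-in for mpu.vocab_parallel_cross_entropy: [s, b, v] logits and
    # [s, b] labels in, un-reduced [s, b] per-token loss out.
    s, b, v = logits.shape
    return F.cross_entropy(logits.reshape(s * b, v),
                           labels.reshape(s * b),
                           reduction='none').view(s, b)

s, b, v = 8, 2, 32                        # sequence length, micro-batch size, vocab
lm_logits = torch.randn(s, b, v)          # model output kept sequence-major: [s, b, v]
lm_labels = torch.randint(0, v, (b, s))   # labels arrive batch-major: [b, s]

# [b, s] => [s, b] to line up with the logits.
lm_labels = lm_labels.transpose(0, 1).contiguous()
lm_loss = toy_vocab_parallel_cross_entropy(lm_logits.float(), lm_labels)

# [s, b] => [b, s] so batch-major consumers (e.g. the loss mask) still work.
lm_loss = lm_loss.transpose(0, 1).contiguous()
assert lm_loss.shape == (b, s)
```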
@@ -167,7 +167,6 @@ class SwitchMLP(MegatronModule):
class CoreAttention(MegatronModule):
matmul_input_buffer = None
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
@@ -235,21 +234,16 @@ class CoreAttention(MegatronModule):
output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
if CoreAttention.matmul_input_buffer is None:
CoreAttention.matmul_input_buffer = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
else:
assert CoreAttention.matmul_input_buffer.size() == \
(output_size[0]*output_size[1], output_size[2], output_size[3]), \
"buffer dimensions should remain the same during the training run"
matmul_input_buffer = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
CoreAttention.matmul_input_buffer,
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
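The change above drops the class-level `matmul_input_buffer` cache and goes back to allocating the `baddbmm` buffer per call. The removed assertion required the [b * np, sq, sk] shape to stay constant for the whole run, which plausibly breaks for T5, where encoder self-attention, decoder self-attention, and cross attention see different sequence lengths. A rough sketch of the pattern with made-up sizes (the real code allocates on the current CUDA device):

```python
import torch

# Illustrative sizes: b*np batched attention heads, query length sq,
# key length sk, per-head hidden size hn.
bnp, sq, sk, hn = 4, 16, 24, 8
query_layer = torch.randn(sq, bnp, hn)    # [sq, b * np, hn]
key_layer = torch.randn(sk, bnp, hn)      # [sk, b * np, hn]
norm_factor = hn ** 0.5

# Allocated fresh on every call, so attention blocks with different sq/sk
# no longer share one fixed-shape buffer. With beta=0.0 the uninitialized
# contents are ignored by baddbmm.
matmul_input_buffer = torch.empty(bnp, sq, sk, dtype=query_layer.dtype)

# Raw attention scores: [b * np, sq, sk]
matmul_result = torch.baddbmm(
    matmul_input_buffer,
    query_layer.transpose(0, 1),                # [b * np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
    beta=0.0, alpha=1.0 / norm_factor)
assert matmul_result.shape == (bnp, sq, sk)
```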
@@ -921,7 +915,7 @@ class ParallelTransformer(MegatronModule):
if self.sequence_parallel:
rng_context = mpu.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext
rng_context = nullcontext()
with rng_context:
# Forward pass.
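The one-character change above (`nullcontext` → `nullcontext()`) matters: the `with` statement needs an instance, and handing it the bare class fails at runtime because the class object itself does not implement the context-manager protocol. A small sketch of the fixed branch, with illustrative names standing in for the Megatron objects:

```python
from contextlib import nullcontext

def run_with_rng_context(sequence_parallel, rng_tracker, forward_fn):
    # Fork the (CUDA) RNG state under sequence parallelism, otherwise use a
    # no-op context. Note the call: nullcontext() returns a usable context
    # manager, while the bare class would raise inside the `with` statement.
    if sequence_parallel:
        rng_context = rng_tracker.fork()
    else:
        rng_context = nullcontext()
    with rng_context:
        return forward_fn()

# The no-op path works without any RNG tracker at all.
print(run_with_rng_context(False, rng_tracker=None, forward_fn=lambda: "forward ran"))
```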
@@ -205,7 +205,6 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
Linear layer execution with asynchronous communication and gradient accumulation
fusion in backprop.
"""
all_gather_buffer = None
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
@@ -221,20 +220,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
else:
assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
"buffer dimensions should remain same during the training run"
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
torch.distributed._all_gather_base(
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
all_gather_buffer,
input,
group=get_tensor_model_parallel_group())
total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
total_input = all_gather_buffer
else:
total_input = input
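Forward path of the rewritten linear function: the sequence-parallel input shard is all-gathered into a buffer that is now allocated per call instead of being cached on the class (the removed assert pinned its shape for the whole run). A sketch of just that gather step, assuming an initialized NCCL process group and using the same private `_all_gather_base` API as the diff (the function name is illustrative):

```python
import torch
import torch.distributed as dist

def gather_sequence_parallel_input(input_, group):
    # Each tensor-parallel rank holds a slice of the sequence dimension;
    # gather the full activation into a freshly allocated buffer.
    world_size = dist.get_world_size(group=group)
    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    all_gather_buffer = torch.empty(dim_size, dtype=input_.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False)
    dist._all_gather_base(all_gather_buffer, input_, group=group)
    return all_gather_buffer  # total_input for the subsequent matmul
```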
@@ -253,15 +247,20 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
handle = torch.distributed._all_gather_base(
LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of input gradient computation shortly (3us) to have
# the gather scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
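Backward path: the same gather is re-issued asynchronously so it can overlap with the input-gradient matmul; the tiny `torch.empty(1) + 1` kernel just nudges the scheduler to enqueue the gather first. A sketch of that overlap (the `handle.wait()` before the weight-gradient matmul sits further down in the original function and is included here only for completeness; names are illustrative):

```python
import torch
import torch.distributed as dist

def backward_gather_and_grad_input(input_, grad_output, weight, group):
    # Re-gather the full input asynchronously for the weight gradient.
    world_size = dist.get_world_size(group=group)
    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    all_gather_buffer = torch.empty(dim_size, dtype=input_.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False)
    handle = dist._all_gather_base(all_gather_buffer, input_,
                                   group=group, async_op=True)

    # Small dummy op: delays the grad_input kernel slightly so the gather
    # gets scheduled (and its GPU resources allocated) first.
    _ = torch.empty(1, device=grad_output.device) + 1

    grad_input = grad_output.matmul(weight)  # overlaps with the gather
    handle.wait()                            # gathered input is ready now
    total_input = all_gather_buffer          # used next for the weight grad
    return grad_input, total_input
```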