Commit 5d647381 authored by Vijay Korthikanti

t5 regression fixes

parent a7a12f82
@@ -114,13 +114,16 @@ def post_language_model_processing(lm_output, pooled_output,
         return lm_logits.transpose(0,1).contiguous(), binary_logits
     else:
         # [b s] => [s b]
-        lm_logits = lm_logits.transpose(0,1).contiguous()
+        lm_labels = lm_labels.transpose(0,1).contiguous()
+        # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                        lm_labels)
+        # [s, b] => [b s]
+        lm_loss = lm_loss.transpose(0,1).contiguous()
         return lm_loss, binary_logits
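The hunk above keeps the logits in Megatron's sequence-major [s, b, h] layout, transposes the labels from [b, s] to [s, b] to match, and transposes the per-token loss back to [b, s] for the caller. Below is a minimal single-GPU sketch of that shape handling; torch.nn.functional.cross_entropy stands in for mpu.vocab_parallel_cross_entropy (which computes the same per-token loss across tensor-parallel vocabulary shards), and the function name and toy shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def toy_lm_loss(lm_logits, lm_labels):
    """Shape-only sketch of the pattern in the hunk above.

    lm_logits: [s, b, v]  (sequence-major, as produced by the model)
    lm_labels: [b, s]     (batch-major, as produced by the data loader)
    Returns a per-token loss in [b, s], matching what callers expect.
    F.cross_entropy stands in for mpu.vocab_parallel_cross_entropy.
    """
    # [b, s] => [s, b] so labels line up with the sequence-major logits.
    lm_labels = lm_labels.transpose(0, 1).contiguous()
    s, b, v = lm_logits.shape
    loss = F.cross_entropy(
        lm_logits.float().reshape(s * b, v),
        lm_labels.reshape(s * b),
        reduction='none').reshape(s, b)
    # [s, b] => [b, s] before handing the loss back to the training loop.
    return loss.transpose(0, 1).contiguous()

if __name__ == "__main__":
    logits = torch.randn(5, 2, 11)          # [s=5, b=2, vocab=11]
    labels = torch.randint(0, 11, (2, 5))   # [b=2, s=5]
    print(toy_lm_loss(logits, labels).shape)  # torch.Size([2, 5])
```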
@@ -49,6 +49,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
         else:
             loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+        # [s b] => [b, s]
+        loss = loss.transpose(0,1).contiguous()
         return loss
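The same transpose is added here so the returned per-token loss is batch-major. Downstream, the training-side loss function typically multiplies this tensor element-wise by a [b, s] loss mask before averaging; the sketch below shows that step under the assumption that both tensors are [b, s] (masked_loss is a hypothetical name, not a helper from the repository).

```python
import torch

def masked_loss(per_token_loss, loss_mask):
    """Hypothetical masked-average step, assuming both tensors are [b, s].

    This is why the per-token loss is transposed back to batch-major above:
    the loss mask from the data pipeline is batch-major, and an element-wise
    product of mismatched layouts would silently weight the wrong tokens.
    """
    loss_mask = loss_mask.float().view(-1)
    return torch.sum(per_token_loss.reshape(-1) * loss_mask) / loss_mask.sum()

# Example: ignore padding positions when averaging.
loss = torch.rand(4, 8)   # [b, s]
mask = torch.ones(4, 8)
mask[:, 6:] = 0           # last two positions are padding
print(masked_loss(loss, mask))
```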
@@ -152,7 +152,7 @@ class T5Model(MegatronModule):
         if self.post_process and self.add_decoder:
             decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
             lm_logits = self.lm_head(decoder_output,
                                      self.word_embeddings_weight())
@@ -161,13 +161,15 @@ class T5Model(MegatronModule):
                 return lm_logits.transpose(0,1).contiguous()
             else:
                 # [b s] => [s b]
-                lm_labels = lm_lables.transpose(0,1).contiguous()
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                 else:
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0,1).contiguous()
                 return lm_loss
         elif self.add_decoder and not self.add_encoder:
             decoder_output, encoder_output = lm_output
@@ -167,7 +167,6 @@ class SwitchMLP(MegatronModule):
 class CoreAttention(MegatronModule):
-    matmul_input_buffer = None
 
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
@@ -235,21 +234,16 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input_buffer is None:
-            CoreAttention.matmul_input_buffer = torch.empty(
-                output_size[0]*output_size[1],
-                output_size[2],
-                output_size[3],
-                dtype=query_layer.dtype,
-                device=torch.cuda.current_device())
-        else:
-            assert CoreAttention.matmul_input_buffer.size() == \
-                (output_size[0]*output_size[1], output_size[2], output_size[3]), \
-                "buffer dimensions should remain the same during the training run"
+        matmul_input_buffer = torch.empty(
+            output_size[0]*output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device())
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input_buffer,
+            matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
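This hunk drops the class-level matmul_input_buffer cache (and its assertion that attention shapes never change during a run) in favor of a per-call torch.empty allocation handed to torch.baddbmm. Because baddbmm is invoked with beta=0.0, the buffer's initial contents are never read, so an uninitialized tensor is safe. A self-contained sketch of the raw-score computation under that pattern; shapes follow the [b * np, sq, hn] comments in the hunk, and the function name is illustrative.

```python
import math
import torch

def attention_scores(query_layer, key_layer, norm_factor):
    """Minimal sketch of the raw-score computation above.

    query_layer: [sq, b * np, hn]
    key_layer:   [sk, b * np, hn]
    Returns:     [b * np, sq, sk]

    The buffer is allocated on every call; with beta=0.0 its initial
    contents are ignored, so torch.empty is safe and no cross-iteration
    caching (or shape assertion) is needed.
    """
    sq, bnp, hn = query_layer.shape
    sk = key_layer.shape[0]
    matmul_input_buffer = torch.empty(
        bnp, sq, sk,
        dtype=query_layer.dtype, device=query_layer.device)
    return torch.baddbmm(
        matmul_input_buffer,
        query_layer.transpose(0, 1),                 # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),   # [b * np, hn, sk]
        beta=0.0, alpha=(1.0 / norm_factor))

scores = attention_scores(torch.randn(16, 8, 64), torch.randn(16, 8, 64),
                          norm_factor=math.sqrt(64))
print(scores.shape)  # torch.Size([8, 16, 16])
```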
@@ -921,7 +915,7 @@ class ParallelTransformer(MegatronModule):
         if self.sequence_parallel:
             rng_context = mpu.get_cuda_rng_tracker().fork()
         else:
-            rng_context = nullcontext
+            rng_context = nullcontext()
 
         with rng_context:
             # Forward pass.
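This one-character fix instantiates contextlib.nullcontext rather than assigning the class object: a `with` statement needs a context-manager instance, and entering the bare class fails at runtime. A minimal sketch of the pattern; forward_pass and rng_tracker are illustrative stand-ins for the surrounding Megatron code.

```python
from contextlib import nullcontext

def forward_pass(sequence_parallel, rng_tracker=None):
    """Pick a context manager, then enter it, as in the hunk above.

    Passing the nullcontext class itself (instead of an instance) makes
    the `with` statement fail when it tries to enter the context, which
    is what the fix above addresses.
    """
    if sequence_parallel and rng_tracker is not None:
        rng_context = rng_tracker.fork()   # e.g. mpu.get_cuda_rng_tracker().fork()
    else:
        rng_context = nullcontext()        # note the call: an instance, not the class

    with rng_context:
        return "forward pass ran inside the selected context"

print(forward_pass(sequence_parallel=False))
```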
@@ -205,7 +205,6 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """
-    all_gather_buffer = None
 
     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
@@ -221,20 +220,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
-            if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
-                    torch.empty(dim_size, dtype=input.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False)
-            else:
-                assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
-                    "buffer dimensions should remain same during the training run"
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
 
             torch.distributed._all_gather_base(
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                all_gather_buffer,
                 input,
                 group=get_tensor_model_parallel_group())
-            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
+            total_input = all_gather_buffer
         else:
             total_input = input
@@ -253,15 +247,20 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
             handle = torch.distributed._all_gather_base(
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                all_gather_buffer,
                 input,
                 group=get_tensor_model_parallel_group(), async_op=True)
 
             # Delay the start of intput gradient computation shortly (3us) to have
             # gather scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
-            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
+            total_input = all_gather_buffer
         else:
             total_input = input
 
         grad_input = grad_output.matmul(weight)
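In both the forward and backward paths, the shared class attribute all_gather_buffer (with its fixed-shape assertion) is replaced by a buffer allocated per call and sized to world_size times the first dimension of the local input. The sketch below shows that sizing pattern; it uses the public torch.distributed.all_gather into contiguous views of the fresh buffer as a stand-in for the private _all_gather_base call in the diff (exposed in newer PyTorch as all_gather_into_tensor), and spins up a single-process gloo group purely so the example runs end to end.

```python
import os
import torch
import torch.distributed as dist

def gather_along_first_dim(t, group=None):
    """Sketch of the per-call buffer allocation used above (stand-in API).

    A fresh output buffer of shape [world_size * s, ...] is allocated on
    every call, so nothing is cached across iterations and no shape
    assertion is needed. The real code gathers directly into the buffer
    with torch.distributed._all_gather_base over the
    tensor-model-parallel group and CUDA tensors.
    """
    world_size = dist.get_world_size(group)
    dim_size = list(t.size())
    dim_size[0] = dim_size[0] * world_size
    all_gather_buffer = torch.empty(dim_size, dtype=t.dtype,
                                    device=t.device, requires_grad=False)
    # chunk() along dim 0 of a contiguous tensor returns contiguous views,
    # so all_gather fills the freshly allocated buffer in place.
    chunks = list(all_gather_buffer.chunk(world_size, dim=0))
    dist.all_gather(chunks, t, group=group)
    return all_gather_buffer

if __name__ == "__main__":
    # Single-process group purely so the example runs without launchers.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(gather_along_first_dim(torch.randn(4, 3)).shape)  # torch.Size([4, 3])
    dist.destroy_process_group()
```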