Commit 269f28f7 authored by Vijay Korthikanti

fixes to main merge

parent 6fdbf26b
@@ -291,6 +291,11 @@ def parse_args(extra_args_provider=None, defaults={},
         'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
         'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
 
+    # model parallel memory optimization
+    if args.model_parallel_memory_opt:
+        assert not args.async_tensor_model_parallel_allreduce
+
     _print_args(args)
     return args
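The new check makes the two flags mutually exclusive: the memory optimization and the asynchronous tensor-model-parallel all-reduce select different gradient-communication paths in the parallel linear layers, so only one may be enabled at a time. A minimal sketch of the same validation, assuming a hypothetical helper name and error message (the change itself inlines a bare assert in parse_args):

def _check_memory_opt_args(args):
    # Hypothetical helper for illustration only; the commit inlines this check.
    if args.model_parallel_memory_opt:
        assert not args.async_tensor_model_parallel_allreduce, \
            'model_parallel_memory_opt cannot be combined with ' \
            'async_tensor_model_parallel_allreduce'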
@@ -34,23 +34,21 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or\
        args.model_parallel_memory_opt:
-        input_parallel = input
+        input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel
         model_parallel_memory_opt = args.model_parallel_memory_opt and \
             model_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
         model_parallel_memory_opt = False
 
     # Matrix multiply.
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, model_parallel_memory_opt)
+        async_grad_allreduce, None)
 
     # Gather if needed.
     if parallel_output:
         return logits_parallel
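The first branch hands the raw input_ to LinearWithGradAccumulationAndAsyncCommunication, which performs the input-gradient reduction itself (optionally overlapped with the weight-gradient GEMM), while the fallback branch routes the tensor through copy_to_tensor_model_parallel_region, an op that is an identity in the forward pass and an all-reduce in the backward pass. A conceptual sketch of that op, not the Megatron implementation, assuming an explicit process-group argument:

import torch
import torch.distributed as dist

class _CopyToModelParallelRegionSketch(torch.autograd.Function):
    """Identity forward, all-reduce of the gradient backward (sketch only)."""

    @staticmethod
    def forward(ctx, input_, group=None):
        ctx.group = group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient contributions from all tensor-parallel ranks.
        dist.all_reduce(grad_output, group=ctx.group)
        # One return value per forward argument: grad for input_, None for group.
        return grad_output, None

Because the reduction here is a separate autograd node, it cannot be overlapped with the weight-gradient computation, which is what the asynchronous path avoids.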
@@ -881,7 +881,6 @@ class ParallelTransformer(MegatronModule):
                     enc_dec_attn_mask=enc_dec_attn_mask,
                     inference_params=inference_params)
 
-
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
@@ -899,5 +898,4 @@ class ParallelTransformer(MegatronModule):
         else:
             output = hidden_states
 
-
         return output
@@ -49,7 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
@@ -299,12 +299,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.model_parallel_memory_opt:
             handle.wait()
-            return sub_grad_input, grad_weight, grad_bias
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
 
         if ctx.async_grad_allreduce:
             handle.wait()
 
-        return grad_input, grad_weight, grad_bias
+        return grad_input, grad_weight, grad_bias, None, None, None
 
 
 class ColumnParallelLinear(torch.nn.Module):
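forward() of LinearWithGradAccumulationAndAsyncCommunication is invoked with six arguments (input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce and the memory-opt flag, judging by the call sites in this commit), and torch.autograd.Function requires backward() to return exactly one value per forward() argument, so both return paths are padded with None for the three non-tensor flags. A toy sketch of that rule, unrelated to the Megatron kernels themselves:

import torch

class ScaledLinear(torch.autograd.Function):
    """Toy example: backward() returns one value per forward() argument."""

    @staticmethod
    def forward(ctx, input_, weight, scale_flag):
        ctx.save_for_backward(input_, weight)
        ctx.scale = 2.0 if scale_flag else 1.0
        return ctx.scale * input_.matmul(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_output = grad_output * ctx.scale
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input_)
        # Three forward arguments -> three return values; the bool flag gets None.
        return grad_input, grad_weight, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
ScaledLinear.apply(x, w, True).sum().backward()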
@@ -504,9 +504,9 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.model_parallel_memory_opt
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
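With model_parallel_memory_opt enabled, RowParallelLinear combines the partial outputs with a reduce-scatter along the sequence dimension instead of a full all-reduce, so each tensor-parallel rank keeps only its own sequence shard of the activation. A rough sketch of the difference between the two collectives, assuming sharding along dim 0 and recent torch.distributed (reduce_scatter_tensor); Megatron wraps the equivalent calls in the *_region autograd functions used above:

import torch
import torch.distributed as dist

def combine_partial_outputs(output_parallel, memory_opt, group=None):
    """Sketch: reduce-scatter (sequence-parallel) vs. all-reduce (default)."""
    world_size = dist.get_world_size(group)
    if memory_opt:
        # Each rank ends up with a 1/world_size slice of the summed activation
        # along dim 0, which is what saves activation memory downstream.
        seq_len = output_parallel.size(0)
        assert seq_len % world_size == 0
        output = output_parallel.new_empty(
            (seq_len // world_size,) + tuple(output_parallel.shape[1:]))
        dist.reduce_scatter_tensor(output, output_parallel.contiguous(), group=group)
        return output
    # Default path: every rank ends up with the full summed activation.
    dist.all_reduce(output_parallel, group=group)
    return output_parallel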