Commit 269f28f7 authored by Vijay Korthikanti

fixes to main merge

parent 6fdbf26b
@@ -291,6 +291,11 @@ def parse_args(extra_args_provider=None, defaults={},
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
+    # model parallel memory optimization
+    if args.model_parallel_memory_opt:
+        assert not args.async_tensor_model_parallel_allreduce
     _print_args(args)
     return args
...
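The assertion above makes the new memory optimization and the asynchronous tensor-model-parallel all-reduce mutually exclusive. A minimal, standalone sketch of the same check; the flag names follow the diff, but the argparse wiring is illustrative and not the repository's parse_args:

import argparse

def parse_args_sketch(argv=None):
    # Toy parser exposing only the two flags involved in the new check.
    parser = argparse.ArgumentParser(description='toy mutual-exclusion check')
    parser.add_argument('--model-parallel-memory-opt', action='store_true')
    parser.add_argument('--async-tensor-model-parallel-allreduce', action='store_true')
    args = parser.parse_args(argv)

    # model parallel memory optimization: mirrors the assertion added above.
    if args.model_parallel_memory_opt:
        assert not args.async_tensor_model_parallel_allreduce, \
            'model_parallel_memory_opt cannot be combined with ' \
            'async_tensor_model_parallel_allreduce'
    return args

if __name__ == '__main__':
    parse_args_sketch(['--model-parallel-memory-opt'])  # passes; passing both flags would trip the assert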
@@ -34,23 +34,21 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or\
        args.model_parallel_memory_opt:
-        input_parallel = input
+        input_parallel = input_
         model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel
-        model_parallel_memory_opt = args.model_parallel_memory_opt and \
-            model_parallel
     else:
         input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False
-        model_parallel_memory_opt = False
     # Matrix multiply.
     logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
         input_parallel, word_embeddings_weight, bias,
         args.gradient_accumulation_fusion,
-        async_grad_allreduce, model_parallel_memory_opt)
+        async_grad_allreduce, None)
     # Gather if needed.
     if parallel_output:
         return logits_parallel
...
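In the hunk above, the fused path now feeds the raw input_ straight into LinearWithGradAccumulationAndAsyncCommunication, which handles the backward communication itself, while the else branch keeps the explicit mpu.copy_to_tensor_model_parallel_region call. As a reminder of what that helper amounts to, here is a minimal stand-in: identity in the forward pass, all-reduce of the gradient in the backward pass. Using the default process group is a simplification for illustration, not Megatron's implementation.

import torch
import torch.distributed as dist

class _CopyToModelParallelRegionSketch(torch.autograd.Function):
    """Identity forward; sum-reduce the gradient across ranks in backward."""

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank contributes its local gradient; the fused linear computes
        # this same sum asynchronously instead of via an explicit copy op.
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(grad_output)
        return grad_output

def copy_to_tensor_model_parallel_region_sketch(input_):
    return _CopyToModelParallelRegionSketch.apply(input_)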
@@ -881,7 +881,6 @@ class ParallelTransformer(MegatronModule):
                     enc_dec_attn_mask=enc_dec_attn_mask,
                     inference_params=inference_params)
-
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
@@ -899,5 +898,4 @@ class ParallelTransformer(MegatronModule):
         else:
             output = hidden_states
-
         return output
...
@@ -49,7 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
...
@@ -299,12 +299,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.model_parallel_memory_opt:
             handle.wait()
-            return sub_grad_input, grad_weight, grad_bias
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
         if ctx.async_grad_allreduce:
             handle.wait()
-        return grad_input, grad_weight, grad_bias
+        return grad_input, grad_weight, grad_bias, None, None, None
 class ColumnParallelLinear(torch.nn.Module):
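The extra `, None, None, None` on both return statements follows from autograd's contract: backward() must return exactly one value per positional argument of forward(), and the non-tensor arguments (the gradient-accumulation, async-all-reduce, and memory-optimization flags seen at the call sites above) get None. A self-contained toy with the same six-argument shape, not the Megatron kernel:

import torch

class LinearLikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight, bias,
                gradient_accumulation_fusion, async_grad_allreduce, memory_opt):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        output = input_.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input_)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        # One slot per forward argument: the three trailing flags are not
        # differentiable, hence the trailing Nones, matching the updated
        # returns in this hunk.
        return grad_input, grad_weight, grad_bias, None, None, None

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
b = torch.randn(16, requires_grad=True)
LinearLikeFunction.apply(x, w, b, False, False, False).sum().backward()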
@@ -504,9 +504,9 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.model_parallel_memory_opt
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
...
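In the RowParallelLinear hunk above, the memory optimization replaces the usual all-reduce of the row-parallel output with a reduce-scatter along the sequence dimension (reduce_scatter_to_sequence_parallel_region), so each tensor-parallel rank keeps only its own, already summed, sequence shard. A hedged, self-contained sketch of that choice of reduction, assuming the default process group and a sequence length divisible by the world size:

import torch
import torch.distributed as dist

def reduce_row_parallel_output(output_parallel: torch.Tensor, memory_opt: bool) -> torch.Tensor:
    """Sum partial [s, b, h] results across tensor-parallel ranks.

    memory_opt=False: classic all-reduce; every rank ends with the full tensor.
    memory_opt=True:  reduce-scatter over the sequence dim; each rank keeps [s / p, b, h].
    """
    world_size = dist.get_world_size()
    if world_size == 1:
        return output_parallel
    if memory_opt:
        # Split the sequence dimension into one chunk per rank; reduce_scatter
        # sums chunk i across ranks and leaves the result on rank i.
        chunks = [c.contiguous() for c in output_parallel.chunk(world_size, dim=0)]
        shard = torch.empty_like(chunks[0])
        dist.reduce_scatter(shard, chunks)
        return shard
    dist.all_reduce(output_parallel)
    return output_parallel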