Commit 5d4689c4 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

tensor model parallelism memory optmization

parent 0ed2f6ac
......@@ -530,6 +530,10 @@ def _add_training_args(parser):
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--model-parallel-memory-opt', action='store_true',
help='Enable model parallel memory optmization.')
return parser
......
......@@ -40,7 +40,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if parallel_output:
return logits_parallel
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
......
......@@ -628,6 +628,8 @@ class ParallelTransformer(MegatronModule):
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
self.model_parallel_memory_opt = args.model_parallel_memory_opt
# Number of layers.
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
......@@ -771,6 +773,10 @@ class ParallelTransformer(MegatronModule):
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if self.model_parallel_memory_opt:
hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(hidden_states)
else:
# See set_input_tensor()
hidden_states = self.input_tensor
......@@ -820,9 +826,14 @@ class ParallelTransformer(MegatronModule):
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
hidden_states = self.final_layernorm(hidden_states)
if self.model_parallel_memory_opt:
hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(hidden_states)
output = hidden_states.transpose(0, 1).contiguous()
else:
output = hidden_states
return output
......@@ -55,11 +55,15 @@ from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
......
......@@ -29,9 +29,12 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -307,6 +310,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
self.model_parallel_memory_opt = args.model_parallel_memory_opt
......@@ -323,14 +327,18 @@ class ColumnParallelLinear(torch.nn.Module):
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.model_parallel_memory_opt:
input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
assert not self.model_parallel_memory_opt
output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -416,6 +424,7 @@ class RowParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
self.model_parallel_memory_opt = args.model_parallel_memory_opt
def forward(self, input_):
......@@ -423,11 +432,15 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
assert not self.model_parallel_memory_opt
input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if self.model_parallel_memory_opt:
output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
......@@ -32,7 +32,8 @@ def _reduce(input_):
return input_
def _split(input_):
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
......@@ -50,8 +51,28 @@ def _split(input_):
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * (local_dim_size)
output = input_[dim_offset:dim_offset+local_dim_size]
return output
def _gather(input_):
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
......@@ -73,6 +94,54 @@ def _gather(input_):
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
torch.distributed._all_gather_base(output, input_,
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0
dim_size[0]= dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
torch.distributed._reduce_scatter_base(output, input_,
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_last_dim(input_):
output = _reduce(input_)
output = _split_along_last_dim(output)
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
......@@ -105,36 +174,100 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
return _reduce_scatter_along_last_dim(grad_output)
class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate.""" #TODO
@staticmethod
def symbolic(graph, input_):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _reduce_scatter_along_first_dim(grad_output)
class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
......@@ -149,9 +282,25 @@ def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
def gather_along_last_dim_from_tensor_model_parallel_region(input_):
return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
def gather_along_first_dim_from_tensor_model_parallel_region(input_):
return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment