Commit 5d4689c4 authored by Vijay Korthikanti

tensor model parallelism memory optimization

parent 0ed2f6ac
@@ -530,6 +530,10 @@ def _add_training_args(parser):
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
+    group.add_argument('--model-parallel-memory-opt', action='store_true',
+                       help='Enable model parallel memory optimization.')
+
     return parser
......
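Note: a minimal, stand-alone sketch (not part of the commit) of how the new store_true flag behaves once parsed — it defaults to False and flips to True only when --model-parallel-memory-opt is passed on the command line.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--model-parallel-memory-opt', action='store_true',
                   help='Enable model parallel memory optimization.')

assert parser.parse_args([]).model_parallel_memory_opt is False
assert parser.parse_args(['--model-parallel-memory-opt']).model_parallel_memory_opt is True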
@@ -40,7 +40,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
+    return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,
......
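Note: a single-process sketch (not part of the commit; shapes are illustrative) of what the last-dim gather does to vocab-parallel logits: each tensor-parallel rank holds a [s, b, vocab/p] shard, and the all-gather concatenates the shards along the last (vocab) dimension so every rank ends up with the full [s, b, vocab] logits.

import torch

s, b, vocab, p = 4, 2, 16, 4                                # sequence, batch, vocab, TP world size
shards = [torch.randn(s, b, vocab // p) for _ in range(p)]  # one logits shard per rank
full_logits = torch.cat(shards, dim=-1)                     # what the all-gather yields on every rank
assert full_logits.shape == (s, b, vocab)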
@@ -628,6 +628,8 @@ class ParallelTransformer(MegatronModule):
         self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
         self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
+
         # Number of layers.
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
@@ -771,6 +773,10 @@ class ParallelTransformer(MegatronModule):
             # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+            if self.model_parallel_memory_opt:
+                hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(hidden_states)
+
         else:
             # See set_input_tensor()
             hidden_states = self.input_tensor
@@ -820,9 +826,14 @@ class ParallelTransformer(MegatronModule):
         # Final layer norm.
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
-            output = self.final_layernorm(hidden_states)
+            hidden_states = self.final_layernorm(hidden_states)
+
+            if self.model_parallel_memory_opt:
+                hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(hidden_states)
+
+            output = hidden_states.transpose(0, 1).contiguous()
         else:
             output = hidden_states

         return output
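Note: a single-process sketch (not part of the commit; sizes are illustrative) of the activation bookkeeping this optimization adds around the transformer stack: in the [s, b, h] layout, hidden states are split along the sequence (first) dimension so each tensor-parallel rank keeps only s/p of the activations between blocks, and the full sequence is restored by an all-gather before the final transpose back to [b, s, h].

import torch

s, b, h, p, rank = 8, 2, 16, 4, 1
hidden = torch.randn(b, s, h).transpose(0, 1).contiguous()   # [s, b, h] layout used by the layers

local = hidden[rank * (s // p):(rank + 1) * (s // p)]        # scatter along the first (sequence) dim
assert local.shape == (s // p, b, h)                         # 1/p of the activation memory per rank

# ... transformer layers and the final layernorm operate on the local shard ...

gathered = torch.cat([local] * p, dim=0)                     # stand-in for the first-dim all-gather
output = gathered.transpose(0, 1).contiguous()               # revert to [b, s, h]
assert output.shape == (b, s, h)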
@@ -55,11 +55,15 @@ from .layers import VocabParallelEmbedding
 from .layers import (set_tensor_model_parallel_attributes,
                      set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)
 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
+from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
......
@@ -29,9 +29,12 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
+from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
@@ -307,6 +310,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             not args.no_async_tensor_model_parallel_allreduce and
             world_size > 1)
+        self.model_parallel_memory_opt = args.model_parallel_memory_opt
@@ -323,14 +327,18 @@
                 input_shape[0], input_shape[1], output_parallel.shape[1])
         else:
             # Set up backprop all-reduce.
-            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            if self.model_parallel_memory_opt:
+                input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
+            else:
+                input_parallel = copy_to_tensor_model_parallel_region(input_)
             # Matrix multiply.
             output_parallel = F.linear(input_parallel, self.weight, bias)

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_tensor_model_parallel_region(output_parallel)
+            assert not self.model_parallel_memory_opt
+            output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -416,6 +424,7 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)

+        self.model_parallel_memory_opt = args.model_parallel_memory_opt

     def forward(self, input_):
@@ -423,11 +432,15 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_tensor_model_parallel_region(input_)
+            assert not self.model_parallel_memory_opt
+            input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.model_parallel_memory_opt:
+            output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
......
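Note: a single-process sketch (not part of the commit; sizes are illustrative) of why the new RowParallelLinear path is numerically equivalent to the old one: the old path all-reduces the per-rank partial outputs so every rank holds the full [s, b, h] sum, while the new path reduce-scatters along the first dimension so rank r keeps only its s/p slice of exactly that same sum.

import torch

torch.manual_seed(0)
s, b, h, p = 8, 2, 16, 4
partials = [torch.randn(s, b, h) for _ in range(p)]   # per-rank partial results of the row-parallel GEMM

# Old path: all-reduce; every rank ends up with the full summed output.
full = torch.stack(partials).sum(dim=0)

# New path: reduce-scatter along dim 0; rank r keeps rows [r*s/p, (r+1)*s/p).
chunk = s // p
for r in range(p):
    shard = sum(x[r * chunk:(r + 1) * chunk] for x in partials)
    assert torch.allclose(shard, full[r * chunk:(r + 1) * chunk])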
@@ -32,7 +32,8 @@ def _reduce(input_):
     return input_


-def _split(input_):
+def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
@@ -50,8 +51,28 @@ def _split(input_):
     return output


+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset + local_dim_size]
+
+    return output
+
+
-def _gather(input_):
+def _gather_along_last_dim(input_):
     """Gather tensors and concatenate along the last dimension."""

     world_size = get_tensor_model_parallel_world_size()
@@ -73,6 +94,54 @@ def _gather(input_):
     return output
+
+
+def _gather_along_first_dim(input_):
+    """Gather tensors and concatenate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device(),
+                         requires_grad=False)
+    torch.distributed._all_gather_base(output, input_,
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device(),
+                         requires_grad=False)
+    # reduce_scatter
+    torch.distributed._reduce_scatter_base(output, input_,
+                                           group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_last_dim(input_):
+    """Reduce the input, then split along its last dimension and keep the
+    corresponding slice."""
+    output = _reduce(input_)
+    output = _split_along_last_dim(output)
+
+    return output


 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
@@ -105,36 +174,100 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
         return grad_output


-class _ScatterToModelParallelRegion(torch.autograd.Function):
+class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chunk to the rank."""

     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


-class _GatherFromModelParallelRegion(torch.autograd.Function):
+class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate."""

     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _reduce_scatter_along_last_dim(grad_output)
+
+
+class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chunk to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
+    """Gather the input from model parallel region and concatenate.""" #TODO
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce_scatter_along_first_dim(grad_output)
+
+
+class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_last_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_last_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_last_dim(grad_output)
+
+
+class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)


 # -----------------
@@ -149,9 +282,25 @@ def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_to_tensor_model_parallel_region(input_):
-    return _ScatterToModelParallelRegion.apply(input_)
+def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
+    return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
+
+
+def gather_along_last_dim_from_tensor_model_parallel_region(input_):
+    return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
+
+
+def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
+    return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
+
+
+def gather_along_first_dim_from_tensor_model_parallel_region(input_):
+    return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
+
+
+def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
+    return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)


-def gather_from_tensor_model_parallel_region(input_):
-    return _GatherFromModelParallelRegion.apply(input_)
+def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
+    return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)
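Note: a toy, single-process analogue (not part of the commit; the class name is made up) of the forward/backward pairing the new autograd functions rely on: an op that splits along the first dimension in forward must gather gradients along that dimension in backward, mirroring how _ScatterAlongFirstDimToModelParallelRegion pairs _split_along_first_dim with _gather_along_first_dim.

import torch


class ToySplitAlongFirstDim(torch.autograd.Function):
    """Keep one first-dim chunk in forward; in backward, place the incoming gradient
    back into the full-size tensor (zeros stand in for the other ranks' gradients,
    which the real all-gather would supply)."""

    @staticmethod
    def forward(ctx, input_, rank, world_size):
        ctx.rank, ctx.world_size = rank, world_size
        chunk = input_.size(0) // world_size
        return input_[rank * chunk:(rank + 1) * chunk].clone()

    @staticmethod
    def backward(ctx, grad_output):
        chunks = [torch.zeros_like(grad_output) for _ in range(ctx.world_size)]
        chunks[ctx.rank] = grad_output
        return torch.cat(chunks, dim=0), None, None


x = torch.randn(8, 4, requires_grad=True)
y = ToySplitAlongFirstDim.apply(x, 1, 4)
y.sum().backward()
assert x.grad.shape == x.shape   # gradient comes back with the full, gathered first dimension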