Commit 9b174da8 authored by Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
......@@ -15,9 +15,9 @@
import torch
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_src_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
......@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_model_parallel_rank() == 0:
if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
......@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
group=get_model_parallel_group())
torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group())
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
......@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data)
# Pack on rank zero.
if get_model_parallel_rank() == 0:
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
......@@ -100,9 +100,9 @@ def broadcast_data(keys, data, datatype):
device=torch.cuda.current_device(),
dtype=datatype)
# Boradcast
torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
group=get_model_parallel_group())
# Broadcast
torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group())
# Unpack
output = {}
......
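The broadcast helpers in this hunk keep the same pack/broadcast/unpack pattern, now scoped to the tensor model-parallel group: rank 0 of that group flattens every key's shape into one fixed-width list, broadcasts it, and the other ranks rebuild the per-key sizes and element counts. A standalone sketch of that packing scheme (hypothetical helper names, no torch.distributed needed):

# Standalone sketch of the pack/unpack scheme used by
# _build_key_size_numel_dictionaries above. Helper names are illustrative;
# the real code broadcasts `sizes` with torch.distributed over the
# tensor model-parallel group.
_MAX_DATA_DIM = 4  # same constant as in the hunk above

def pack_sizes(keys, shapes):
    """Flatten per-key tensor shapes into one fixed-width list."""
    sizes = [0] * (_MAX_DATA_DIM * len(keys))
    offset = 0
    for key in keys:
        shape = shapes[key]                     # here: a plain shape tuple
        assert len(shape) < _MAX_DATA_DIM, 'increase _MAX_DATA_DIM'
        for i, s in enumerate(shape):
            sizes[i + offset] = s
        offset += _MAX_DATA_DIM
    return sizes

def unpack_sizes(keys, sizes):
    """Recover per-key shapes and element counts from the flat list."""
    key_size, key_numel, total_numel = {}, {}, 0
    offset = 0
    for key in keys:
        shape = [s for s in sizes[offset:offset + _MAX_DATA_DIM] if s > 0]
        numel = 1
        for s in shape:
            numel *= s
        key_size[key], key_numel[key] = shape, numel
        total_numel += numel
        offset += _MAX_DATA_DIM
    return key_size, key_numel, total_numel

if __name__ == '__main__':
    keys = ['text', 'types']
    shapes = {'text': (8, 512), 'types': (8, 512, 1)}
    flat = pack_sizes(keys, shapes)             # what rank 0 would broadcast
    print(unpack_sizes(keys, flat))             # what every rank recovers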
......@@ -28,8 +28,9 @@ try:
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_tensor_model_parallel_rank
def l2_grad_clipper(parameters, max_norm):
......@@ -43,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_model_parallel_rank() == 0)
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
......@@ -71,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2):
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
......@@ -90,13 +91,27 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
......@@ -105,16 +120,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
#elif norm_type == 2:
# total_norm = l2_grad_clipper(parameters, max_norm)
else:
total_norm = 0
for p in parameters:
if p.model_parallel or (get_model_parallel_rank() == 0):
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model parallel GPUs.
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
......
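In the hunk above, clip_grad_norm first drops embedding gradients on every pipeline stage except the first so the tied embedding weights are counted only once, then accumulates a local L2 norm that is later summed across model-parallel ranks. A CPU-only sketch of that filtering and accumulation (parameter names and the first-stage flag are made up; the tensor-parallel duplicate filter and the all-reduce are only noted in comments):

# CPU-only sketch of the embedding-aware filtering and local L2-norm
# accumulation used by clip_grad_norm above. The real code additionally
# skips tensor-parallel duplicates (counted only on tensor rank 0) and
# sums the partial norms with torch.distributed.all_reduce over the
# model-parallel group before taking the root.
import torch

def local_grad_norm(named_params, is_first_stage, norm_type=2.0):
    total = 0.0
    for name, p in named_params:
        if p.grad is None:
            continue
        # Count the (tied) embedding weights only once, on the first stage.
        if 'embedding' in name and not is_first_stage:
            continue
        total += torch.linalg.norm(p.grad.flatten(), norm_type).item() ** norm_type
    # Real code: all_reduce(total) across model-parallel ranks happens here.
    return total ** (1.0 / norm_type)

if __name__ == '__main__':
    emb = torch.nn.Parameter(torch.randn(10, 4))
    lin = torch.nn.Parameter(torch.randn(4, 4))
    (emb.sum() + lin.sum()).backward()
    params = [('embedding.weight', emb), ('layer.weight', lin)]
    print(local_grad_norm(params, is_first_stage=True))   # counts both
    print(local_grad_norm(params, is_first_stage=False))  # skips the embedding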
......@@ -21,75 +21,148 @@ import torch
from .utils import ensure_divisibility
# Model parallel group that the current rank belongs to.
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(model_parallel_size_):
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1):
"""
Initialize model data parallel groups.
Arguments:
model_parallel_size: number of GPUs used to parallelize model.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel grous as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print('> initializing model parallel with size {}'.format(
model_parallel_size_))
print('> initializing tensor model parallel with size {}'.format(
tensor_model_parallel_size_))
print('> initializing pipeline model parallel with size {}'.format(
pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size)
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size,
tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank()
# Build the data parallel groups.
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
_DATA_PARALLEL_GROUP = group
# Build the model parallel groups.
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank,
tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized'
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size,
(i + 1) * model_parallel_size)
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
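The 16-GPU example in the docstring (tensor=2, pipeline=4, data=2) can be reproduced with the same rank arithmetic used when the groups are built; a pure-Python sketch that only prints the rank lists and creates no process groups:

# Pure-Python sketch of the rank groupings built above, reproducing the
# 16-GPU example in the docstring (tensor=2, pipeline=4, data=2).
def enumerate_groups(world_size, tensor_mp_size, pipeline_mp_size):
    assert world_size % (tensor_mp_size * pipeline_mp_size) == 0
    num_pipeline_groups = world_size // pipeline_mp_size
    data_parallel, tensor_mp, pipeline_mp = [], [], []

    # Data-parallel groups: same tensor rank within each pipeline block.
    for i in range(pipeline_mp_size):
        start, end = i * num_pipeline_groups, (i + 1) * num_pipeline_groups
        for j in range(tensor_mp_size):
            data_parallel.append(list(range(start + j, end, tensor_mp_size)))

    # Tensor model-parallel groups: consecutive ranks.
    for i in range(world_size // tensor_mp_size):
        tensor_mp.append(list(range(i * tensor_mp_size, (i + 1) * tensor_mp_size)))

    # Pipeline model-parallel groups: strided by the number of pipeline groups.
    for i in range(num_pipeline_groups):
        pipeline_mp.append(list(range(i, world_size, num_pipeline_groups)))

    return data_parallel, tensor_mp, pipeline_mp

if __name__ == '__main__':
    dp, tp, pp = enumerate_groups(16, 2, 4)
    print('data parallel:    ', dp)
    print('tensor parallel:  ', tp)
    print('pipeline parallel:', pp)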
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
......@@ -101,6 +174,20 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
'pipeline_model parallel group is not initialized'
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
......@@ -108,41 +195,97 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
def set_model_parallel_world_size(world_size):
"""Set the model parallel size"""
global _MPU_WORLD_SIZE
_MPU_WORLD_SIZE = world_size
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, \
'embedding group is not initialized'
return _EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
global _MPU_WORLD_SIZE
if _MPU_WORLD_SIZE is not None:
return _MPU_WORLD_SIZE
return torch.distributed.get_world_size(group=get_model_parallel_group())
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def set_model_parallel_rank(rank):
"""Set model parallel rank."""
global _MPU_RANK
_MPU_RANK = rank
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
global _MPU_RANK
if _MPU_RANK is not None:
return _MPU_RANK
return torch.distributed.get_rank(group=get_model_parallel_group())
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zeor
in the model parallel group."""
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage():
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage():
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
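The source-rank calculation above rounds the global rank down to the first rank of its tensor model-parallel group; a short worked example with a hypothetical group size of 2:

# Worked example of the source-rank arithmetic above: with a tensor
# model-parallel size of 2, global ranks 0..7 map to 0,0,2,2,4,4,6,6.
tensor_model_parallel_size = 2
for global_rank in range(8):
    src = (global_rank // tensor_model_parallel_size) * tensor_model_parallel_size
    print(global_rank, '->', src)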
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
......@@ -156,7 +299,9 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
......@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.model_parallel = True
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
weight.model_parallel = True
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
......@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.model_parallel_size = get_model_parallel_world_size()
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(),
self.model_parallel_size)
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
......@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim=0, stride=1)
def forward(self, input_):
if self.model_parallel_size > 1:
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
......@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
if self.model_parallel_size > 1:
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
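Each tensor model-parallel rank owns one contiguous slice of the vocabulary, zeroes the output rows for tokens outside its slice, and relies on the all-reduce to assemble the full embedding. A single-process CPU sketch of that lookup with two simulated ranks (the partition helper mirrors the even split that VocabUtility performs; it is not the real code):

# CPU sketch of the vocab-parallel embedding lookup above, simulating two
# tensor model-parallel ranks in one process. The final sum stands in for
# reduce_from_tensor_model_parallel_region (an all-reduce).
import torch

def vocab_range(global_vocab_size, rank, world_size):
    per_partition = global_vocab_size // world_size      # assumed divisible
    start = rank * per_partition
    return start, start + per_partition

def partial_lookup(full_weight, input_ids, rank, world_size):
    start, end = vocab_range(full_weight.size(0), rank, world_size)
    mask = (input_ids < start) | (input_ids >= end)
    local_ids = input_ids.clone() - start
    local_ids[mask] = 0
    out = torch.nn.functional.embedding(local_ids, full_weight[start:end])
    out[mask, :] = 0.0                   # zero rows this rank does not own
    return out

if __name__ == '__main__':
    torch.manual_seed(0)
    vocab, hidden, world_size = 8, 4, 2
    weight = torch.randn(vocab, hidden)
    ids = torch.tensor([0, 3, 5, 7])
    combined = sum(partial_lookup(weight, ids, r, world_size)
                   for r in range(world_size))
    assert torch.equal(combined, torch.nn.functional.embedding(ids, weight))
    print('vocab-parallel lookup matches the dense lookup')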
......@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.model_parallel = True
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero.
......@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
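ColumnParallelLinear splits the weight along the output dimension and gathers the partial outputs, while RowParallelLinear splits along the input dimension and all-reduces the partial products. A single-process CPU check of both identities, with concatenation and summation standing in for the collectives:

# Single-process CPU check of the column/row parallel identities used by
# the layers above; no process groups are involved.
import torch

torch.manual_seed(0)
world_size, batch, d_in, d_out = 2, 3, 6, 4
x = torch.randn(batch, d_in)
w = torch.randn(d_out, d_in)        # full weight, as a single GPU would hold

# Column parallel: each rank holds a slice of the output dimension;
# concatenation plays the role of the all-gather.
col_shards = torch.chunk(w, world_size, dim=0)
col_out = torch.cat([x @ shard.t() for shard in col_shards], dim=-1)
assert torch.allclose(col_out, x @ w.t(), atol=1e-6)

# Row parallel: each rank holds a slice of the input dimension and of the
# input itself; the sum plays the role of the all-reduce.
row_shards = torch.chunk(w, world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=-1)
row_out = sum(xs @ ws.t() for xs, ws in zip(x_shards, row_shards))
assert torch.allclose(row_out, x @ w.t(), atol=1e-6)

print('column- and row-parallel partial results match the dense matmul')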
......@@ -15,7 +15,7 @@
import torch
from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim
......@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_model_parallel_world_size()==1:
if get_tensor_model_parallel_world_size()==1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_model_parallel_group())
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
......@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
......@@ -45,7 +45,7 @@ def _split(input_):
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
......@@ -54,18 +54,18 @@ def _split(input_):
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
......@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_):
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
......@@ -28,9 +28,9 @@ from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
# Default name for the model parallel rng tracker.
......@@ -45,8 +45,8 @@ def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations."""
args = get_args()
per_layer = args.batch_size * args.max_position_embeddings * \
args.hidden_size // args.model_parallel_size
per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.tensor_model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
......@@ -54,7 +54,7 @@ def init_checkpointed_activations_memory_buffer():
dtype = torch.half
if not args.fp16:
dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
'checkpointed activations memory buffer is already allocated.'
......@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_model_parallel_world_size()
start_index = partition_size * get_model_parallel_rank()
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
......@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor,
group=get_model_parallel_group())
group=get_tensor_model_parallel_group())
return gathered
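split_tensor_into_1d_equal_chunks and gather_split_1d_tensor simply slice a flattened tensor into world-size equal pieces and reassemble them; a local sketch with the all-gather replaced by a loop over simulated ranks:

# Local sketch of the 1-D split/gather pair above, simulating the ranks
# of a tensor model-parallel group in one process (no all-gather).
import torch

def split_1d(tensor, rank, world_size):
    data = tensor.view(-1)
    partition = data.numel() // world_size     # assumed evenly divisible
    return data[rank * partition:(rank + 1) * partition]

def gather_1d(chunks):
    return torch.cat(chunks)

if __name__ == '__main__':
    world_size = 4
    t = torch.arange(16.0).view(4, 4)
    chunks = [split_1d(t, r, world_size) for r in range(world_size)]
    assert torch.equal(gather_1d(chunks), t.view(-1))
    print('gather of equal 1-D chunks recovers the flattened tensor')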
......@@ -215,15 +215,15 @@ def model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + get_model_parallel_rank()
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
......@@ -231,15 +231,15 @@ def model_parallel_cuda_manual_seed(seed):
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_model_parallel_rank(),
get_data_parallel_rank(), model_parallel_seed,
torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
model_parallel_seed)
tensor_model_parallel_seed)
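The seeding above gives every tensor model-parallel rank its own dropout stream while keeping the default (data-parallel) stream shared; the offsets themselves are easy to tabulate (hypothetical ranks):

# Offsets used by model_parallel_cuda_manual_seed above: the data-parallel
# ("default") seed is the user seed, and the tensor model-parallel seed
# adds the fixed 2718 offset plus the tensor model-parallel rank.
seed = 1234
for tensor_mp_rank in range(4):                # hypothetical ranks
    data_parallel_seed = seed
    tensor_model_parallel_seed = seed + 2718 + tensor_mp_rank
    print(tensor_mp_rank, data_parallel_seed, tensor_model_parallel_seed)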
class CheckpointFunction(torch.autograd.Function):
......@@ -268,11 +268,11 @@ class CheckpointFunction(torch.autograd.Function):
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
args[0].data)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
......
......@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
......@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return loss, identity.weight.grad
def test_cross_entropy(model_parallel_size):
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * model_parallel_size
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
......@@ -89,7 +89,7 @@ def test_cross_entropy(model_parallel_size):
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
mpu.destroy_tensor_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(model_parallel_size)
model_parallel_size *= 2
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -24,15 +24,15 @@ import sys
sys.path.append("../..")
def test_boradcast_data(model_parallel_size):
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing boradcast_data with model parallel size {} ...'.
format(model_parallel_size))
print('> testing broadcast_data with model parallel size {} ...'.
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
model_parallel_size = mpu.get_model_parallel_world_size()
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1],
......@@ -48,7 +48,7 @@ def test_boradcast_data(model_parallel_size):
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if mpu.get_model_parallel_rank() != 0:
if mpu.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
......@@ -69,7 +69,7 @@ def test_boradcast_data(model_parallel_size):
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
mpu.destroy_model_parallel()
mpu.destroy_tensor_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test test boradcast data')
test_boradcast_data(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test test broadcast data')
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -21,15 +21,15 @@ import sys
sys.path.append("../..")
def test_initialize_model_parallel(model_parallel_size):
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
model_parallel_size))
model_parallel_size_ = min(model_parallel_size,
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(tensor_model_parallel_size,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size_)
mpu.initialize_model_parallel(tensor_model_parallel_size_)
assert mpu.model_parallel_is_initialized()
# Checks.
......@@ -38,15 +38,15 @@ def test_initialize_model_parallel(model_parallel_size):
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = model_parallel_size_
rank = torch.distributed.get_rank() % model_parallel_size_
assert world_size == mpu.get_model_parallel_world_size()
assert rank == mpu.get_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank)
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == mpu.get_tensor_model_parallel_world_size()
assert rank == mpu.get_tensor_model_parallel_rank()
check(mpu.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank)
......@@ -59,20 +59,20 @@ def test_initialize_model_parallel(model_parallel_size):
print('>> passed the test :-)')
def test_get_model_parallel_src_rank(model_parallel_size_):
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_model_parallel_src_rank with size {} ...'.format(
model_parallel_size_))
model_parallel_size = min(model_parallel_size_,
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(tensor_model_parallel_size_,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
assert mpu.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
assert mpu.get_model_parallel_src_rank() == src_rank
src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
assert mpu.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups
mpu.destroy_model_parallel()
......@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(model_parallel_size)
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_model_parallel_src_rank(model_parallel_size)
model_parallel_size *= 2
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -26,14 +26,14 @@ import sys
sys.path.append("../..")
def test_parallel_embedding(model_parallel_size):
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 17
seq_length = 23
......@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size,
1)[mpu.get_model_parallel_rank()]
hidden_size // tensor_model_parallel_size,
1)[mpu.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size,
0)[mpu.get_model_parallel_rank()]
vocab_size // tensor_model_parallel_size,
0)[mpu.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
......@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size):
print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size):
def test_initialize_affine_weight(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
# ---------------
# Column parallel
......@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
......@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
......@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight
def test_column_parallel_linear(model_parallel_size):
def test_column_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
......@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size):
print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size):
def test_row_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
......@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size):
def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size))
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
......@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size):
batch_size = 5
sequence_length = 13
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
model_parallel_size, num_att_heads_per_partition,
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size
......@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size):
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
......@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size):
print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size):
def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size))
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition,
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
......@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size)
model_parallel_size *= 2
test_parallel_embedding(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test column-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test parallel self-attention')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_parallel_self_attention(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test parallel transformer')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size)
model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_parallel_transformer_layer(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -21,14 +21,14 @@ import sys
sys.path.append("../..")
def test_set_cuda_rng_state(model_parallel_size):
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
......@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size):
print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size):
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
......@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size):
print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size):
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank())
mpu.get_tensor_model_parallel_rank())
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
......@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size)
model_parallel_size *= 2
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size)
model_parallel_size *= 2
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size)
model_parallel_size *= 2
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size
args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
......
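The padded vocabulary size is rounded up to the next multiple of make-vocab-size-divisible-by times the tensor model-parallel size, so every partition keeps the same GEMM-friendly width; a standalone sketch of that rounding with illustrative values:

# Standalone version of the rounding performed by _vocab_size_with_padding
# above (argument values below are illustrative).
def pad_vocab_size(orig_vocab_size, make_vocab_size_divisible_by, tensor_mp_size):
    multiple = make_vocab_size_divisible_by * tensor_mp_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

print(pad_vocab_size(30522, 128, 2))   # 30522 padded up to 30720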
......@@ -27,14 +27,16 @@ from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer
def reduce_losses(losses):
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses)
reduced_losses = reduced_losses / torch.distributed.get_world_size()
torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return reduced_losses
return averaged_losses
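average_losses_across_data_parallel_group now sums the loss vector only over the data-parallel group and divides by that group's size; a single-process illustration of the arithmetic, with per-rank values faked as rows of one tensor instead of an all_reduce:

# Single-process illustration of the averaging done above: sum the loss
# vector across the data-parallel ranks, then divide by the group size.
# The per-rank values here are made up; the real code obtains the sum
# from torch.distributed.all_reduce on the data-parallel group.
import torch

per_rank_losses = torch.tensor([[0.90, 1.10],   # rank 0: [lm_loss, sop_loss]
                                [1.10, 0.90],   # rank 1
                                [1.00, 1.00],   # rank 2
                                [1.00, 1.00]])  # rank 3
data_parallel_world_size = per_rank_losses.size(0)
averaged = per_rank_losses.sum(dim=0) / data_parallel_world_size
print(averaged)                                 # mean over the 4 data-parallel ranks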
def report_memory(name):
......@@ -48,14 +50,15 @@ def report_memory(name):
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes)
print_rank_0(string)
if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n'
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
......@@ -66,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max()
norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
......@@ -96,11 +99,11 @@ def get_ltor_masks_and_position_ids(data,
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = batch_size
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
......@@ -122,7 +125,7 @@ def get_ltor_masks_and_position_ids(data,
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(batch_size):
for b in range(micro_batch_size):
# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
......
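Note (not part of the diff): a condensed sketch of the masking logic in this function, using the micro_batch_size and eod_token names from the hunk. The mask is lower-triangular (causal); when reset_attention_mask is set, each sample gets its own mask and tokens after an EOD cannot attend to tokens before it.

import torch

def causal_mask(data, eod_token, reset_attention_mask):
    micro_batch_size, seq_length = data.size()
    # One mask per sample only when attention resets at document boundaries.
    att_mask_batch = micro_batch_size if reset_attention_mask else 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)
    if reset_attention_mask:
        for b in range(micro_batch_size):
            eod_positions = (data[b] == eod_token).nonzero(as_tuple=False).view(-1)
            for i in eod_positions.tolist():
                # Block attention across the document boundary.
                attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
    # Downstream code treats values < 0.5 as "masked".
    return attention_mask < 0.5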
......@@ -23,9 +23,9 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
......@@ -33,10 +33,25 @@ def model_provider():
print_rank_0('building BERT model ...')
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
return model
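Note (not part of the diff): the provider branches on pipeline position because only the first stage owns the input embeddings and only the last stage owns the binary head and parallel output logits; intermediate stages carry transformer layers only. A hedged sanity-check sketch of how the three parallel dimensions tile the job, assuming the renamed mpu accessors used throughout this merge:

import torch
from megatron import mpu

# world_size = tensor_parallel * pipeline_parallel * data_parallel
world_size = torch.distributed.get_world_size()
tp = mpu.get_tensor_model_parallel_world_size()
pp = mpu.get_pipeline_model_parallel_world_size()
dp = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
assert world_size == tp * pp * dp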
......@@ -66,34 +81,51 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator)
timers('batch generator').stop()
timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
# Forward model.
lm_loss_, sop_logits = model(tokens, padding_mask,
tokentype_ids=types,
lm_labels=lm_labels)
if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss
loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss])
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples):
......
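Note (not part of the diff): forward_step gained an input_tensor argument because, under pipeline parallelism, only the first stage reads a batch and only the last stage produces a loss; every other stage returns its activation for the scheduler to send downstream. A generic skeleton of the contract this file (and pretrain_gpt2.py below) follows; names here are illustrative, not a fixed API:

from megatron import mpu

def generic_forward_step(batch_inputs, model, input_tensor):
    # First stage consumes the data-iterator batch; later stages consume the
    # activation received from the previous pipeline stage.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output_tensor = model(*batch_inputs)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor)

    if mpu.is_pipeline_last_stage():
        # Only the last stage turns the model output into a scalar loss
        # (placeholder reduction here) plus a logging dict.
        loss = output_tensor.float().mean()
        return loss, {'lm loss': loss}

    # First/intermediate stages hand the activation back to the training loop,
    # which communicates it to the next stage.
    return output_tensor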
......@@ -23,16 +23,28 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True)
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPT2ModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model
......@@ -69,25 +81,42 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
reduced_loss = reduce_losses([loss])
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]}
return loss, {'lm loss': averaged_loss[0]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples):
......
......@@ -25,12 +25,13 @@ from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False)
......@@ -72,22 +73,22 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
return output
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
timers('batch-generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
timers('batch generator').stop()
timers('batch-generator').stop()
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
local_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1
micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
......@@ -102,11 +103,12 @@ def forward_step(data_iterator, model):
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
return retrieval_loss, stats_dict
......
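Note (not part of the diff): a hedged sketch of the in-batch retrieval objective above. Query and block logits are gathered across the data-parallel group so that every rank scores its queries against the blocks of the whole global batch, and the correct block for query i sits at index i. The forward-only gather below omits the custom backward that AllgatherFromDataParallelRegion provides:

import torch
from megatron import mpu

def allgather_data_parallel(tensor):
    # Forward-only all-gather over the data-parallel group (no gradient flow,
    # unlike the autograd.Function used in pretrain_ict.py).
    group = mpu.get_data_parallel_group()
    world_size = torch.distributed.get_world_size(group=group)
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), group=group)
    return torch.cat(gathered, dim=0)

def retrieval_loss(query_logits, block_logits):
    all_query = allgather_data_parallel(query_logits)    # [global_batch, dim]
    all_block = allgather_data_parallel(block_logits)    # [global_batch, dim]
    scores = torch.matmul(all_query, all_block.t())      # [global_batch, global_batch]
    targets = torch.arange(scores.size(0), device=scores.device)
    # Query i should retrieve block i; all other blocks act as negatives.
    return torch.nn.CrossEntropyLoss()(scores, targets)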