Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
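For callers updating their code against this commit, the rename is purely mechanical. A small lookup table covering the public symbols touched in the hunks below (old name on the left, new name on the right, all within the mpu package):

    # Old symbol -> new symbol, as introduced by this commit.
    RENAMES = {
        "get_intra_layer_model_parallel_group":        "get_tensor_model_parallel_group",
        "get_intra_layer_model_parallel_rank":         "get_tensor_model_parallel_rank",
        "get_intra_layer_model_parallel_world_size":   "get_tensor_model_parallel_world_size",
        "get_intra_layer_model_parallel_src_rank":     "get_tensor_model_parallel_src_rank",
        "get_inter_layer_model_parallel_group":        "get_pipeline_model_parallel_group",
        "get_inter_layer_model_parallel_rank":         "get_pipeline_model_parallel_rank",
        "get_inter_layer_model_parallel_world_size":   "get_pipeline_model_parallel_world_size",
        "is_inter_layer_first_stage":                  "is_pipeline_first_stage",
        "is_inter_layer_last_stage":                   "is_pipeline_last_stage",
        "copy_to_intra_layer_model_parallel_region":   "copy_to_tensor_model_parallel_region",
        "intra_layer_model_parallel_cuda_manual_seed": "model_parallel_cuda_manual_seed",
    }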
......@@ -28,15 +28,15 @@ from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_inter_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_rank
from .initialize import get_inter_layer_model_parallel_rank, set_inter_layer_model_parallel_rank
from .initialize import is_inter_layer_first_stage, is_inter_layer_last_stage
from .initialize import get_intra_layer_model_parallel_src_rank
from .initialize import get_inter_layer_model_parallel_src_rank
from .initialize import get_intra_layer_model_parallel_world_size, set_intra_layer_model_parallel_world_size
from .initialize import get_inter_layer_model_parallel_world_size, set_inter_layer_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_src_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
......@@ -45,15 +45,15 @@ from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .mappings import copy_to_intra_layer_model_parallel_region
from .mappings import gather_from_intra_layer_model_parallel_region
from .mappings import reduce_from_intra_layer_model_parallel_region
from .mappings import scatter_to_intra_layer_model_parallel_region
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import intra_layer_model_parallel_cuda_manual_seed
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .utils import divide
......
......@@ -16,9 +16,9 @@
import torch
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .utils import VocabUtility
......@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=get_intra_layer_model_parallel_group())
group=get_tensor_model_parallel_group())
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_intra_layer_model_parallel_rank()
world_size = get_intra_layer_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size)
......@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_intra_layer_model_parallel_group())
group=get_tensor_model_parallel_group())
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
......@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_intra_layer_model_parallel_group())
group=get_tensor_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
......
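To make the three all-reduces in the hunks above concrete, here is a single-process NumPy sketch of the same algebra, with two simulated tensor-parallel ranks each holding half of a 4-entry vocabulary (the real code replaces the Python-level max and sums with torch.distributed.all_reduce over get_tensor_model_parallel_group()):

    import numpy as np

    logits = np.array([2.0, 0.5, 1.0, 3.0])    # full vocab of 4; target index 3
    shards = [logits[:2], logits[2:]]          # what ranks 0 and 1 of the group would hold
    target, shard = 3, 2                       # shard = per-partition vocab size

    g_max = max(s.max() for s in shards)                           # all_reduce(MAX) of local maxima
    predicted = sum(s[target - shard * r] - g_max                  # all_reduce(SUM): only the rank
                    if shard * r <= target < shard * (r + 1) else 0.0  # owning the target contributes
                    for r, s in enumerate(shards))
    sum_exp = sum(np.exp(s - g_max).sum() for s in shards)         # all_reduce(SUM) of exp sums
    loss = np.log(sum_exp) - predicted

    full = np.exp(logits - logits.max())
    assert np.isclose(loss, -np.log(full[target] / full.sum()))    # matches the unsharded loss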
......@@ -15,9 +15,9 @@
import torch
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_src_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
......@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_intra_layer_model_parallel_rank() == 0:
if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
......@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, get_intra_layer_model_parallel_src_rank(),
group=get_intra_layer_model_parallel_group())
torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group())
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
......@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data)
# Pack on rank zero.
if get_intra_layer_model_parallel_rank() == 0:
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
......@@ -101,8 +101,8 @@ def broadcast_data(keys, data, datatype):
dtype=datatype)
# Broadcast
torch.distributed.broadcast(flatten_data, get_intra_layer_model_parallel_src_rank(),
group=get_intra_layer_model_parallel_group())
torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
group=get_tensor_model_parallel_group())
# Unpack
output = {}
......
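The two hunks above implement one pattern: only tensor-parallel rank 0 materializes the batch, and the other ranks of its group receive it via two broadcasts (the sizes first, then the flattened payload). A hedged usage sketch with hypothetical keys, assuming broadcast_data is re-exported from the mpu package:

    import torch
    from megatron import mpu   # assumed import path

    keys = ['tokens', 'labels']            # hypothetical batch keys
    data = None
    if mpu.get_tensor_model_parallel_rank() == 0:
        # Only the source rank of each tensor-model-parallel group builds the CPU batch.
        data = {'tokens': torch.randint(0, 50000, (4, 128)),
                'labels': torch.randint(0, 50000, (4, 128))}
    # Every rank of the group gets back an identical dict of CUDA int64 tensors.
    batch = mpu.broadcast_data(keys, data, torch.int64)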
......@@ -28,9 +28,9 @@ try:
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_inter_layer_first_stage
from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_tensor_model_parallel_rank
def l2_grad_clipper(parameters, max_norm):
......@@ -44,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_intra_layer_model_parallel_rank() == 0)
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.intra_layer_model_parallel or mp_rank_is_zero, parameters_with_grads))
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
......@@ -101,7 +101,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_inter_layer_first_stage():
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
......@@ -123,7 +123,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
else:
total_norm = 0
for p in parameters:
if p.intra_layer_model_parallel or (get_intra_layer_model_parallel_rank() == 0):
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
......
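The two filters in the hunks above exist because (a) weights that do not carry the tensor_model_parallel attribute are replicated across a tensor-parallel group, so only rank 0 may count them toward the global norm, and (b) the embedding weight is shared between the first and last pipeline stages, so it is counted only on the first stage. A hypothetical helper restating that rule:

    def counts_toward_global_norm(name, param, tensor_model_parallel_rank, is_first_stage):
        """Hypothetical restatement of the two clip_grad_norm filters above."""
        if "embedding" in name and not is_first_stage:
            return False                 # tied embedding: counted once, on the first stage
        # Partitioned weights are counted everywhere; replicated ones only on rank 0.
        return getattr(param, "tensor_model_parallel", False) or tensor_model_parallel_rank == 0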
......@@ -22,10 +22,10 @@ from .utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_INTER_LAYER_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and inter-layer) that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both tensor- and pipeline-parallel) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
......@@ -33,10 +33,10 @@ _EMBEDDING_GROUP = None
_DATA_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_INTRA_LAYER_WORLD_SIZE = None
_MPU_INTER_LAYER_WORLD_SIZE = None
_MPU_INTRA_LAYER_RANK = None
_MPU_INTER_LAYER_RANK = None
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
def is_unitialized():
......@@ -44,25 +44,25 @@ def is_unitialized():
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(intra_layer_model_parallel_size_=1,
inter_layer_model_parallel_size_=1):
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1):
"""
Initialize model data parallel groups.
Arguments:
intra_layer_model_parallel_size: number of GPUs used to parallelize model intra-layer.
inter_layer_model_parallel_size: number of GPUs used to parallelize model inter-layer.
tensor_model_parallel_size: number of GPUs used to parallelize each layer's tensors (tensor model parallelism).
pipeline_model_parallel_size: number of GPUs used to split the model's layers into pipeline stages (pipeline model parallelism).
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model intra-layer, and 4 GPUs to parallelize
the model inter-layer. The present function will
create 8 intra-layer model-parallel groups, 4 inter-layer model-parallel groups
use 2 GPUs for tensor model parallelism, and 4 GPUs for pipeline
model parallelism. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 intra-layer model-parallel groups:
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 inter-layer model-parallel groups:
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
......@@ -70,22 +70,22 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print('> initializing intra-layer model parallel with size {}'.format(
intra_layer_model_parallel_size_))
print('> initializing inter-layer model parallel with size {}'.format(
inter_layer_model_parallel_size_))
print('> initializing tensor model parallel with size {}'.format(
tensor_model_parallel_size_))
print('> initializing pipeline model parallel with size {}'.format(
pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_, world_size)
inter_layer_model_parallel_size = min(inter_layer_model_parallel_size_, world_size)
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size,
intra_layer_model_parallel_size * inter_layer_model_parallel_size)
data_parallel_size = world_size // (intra_layer_model_parallel_size *
inter_layer_model_parallel_size)
tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_intra_layer_model_parallel_groups = world_size // intra_layer_model_parallel_size
num_inter_layer_model_parallel_groups = world_size // inter_layer_model_parallel_size
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank()
......@@ -95,12 +95,12 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
all_data_parallel_group_ranks = []
for i in range(inter_layer_model_parallel_size):
start_rank = i * num_inter_layer_model_parallel_groups
end_rank = (i + 1) * num_inter_layer_model_parallel_groups
for j in range(intra_layer_model_parallel_size):
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank,
intra_layer_model_parallel_size)
tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
......@@ -117,31 +117,31 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the intra-layer model-parallel groups.
global _INTRA_LAYER_MODEL_PARALLEL_GROUP
assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is None, \
'intra-layer model parallel group is already initialized'
for i in range(num_intra_layer_model_parallel_groups):
ranks = range(i * intra_layer_model_parallel_size,
(i + 1) * intra_layer_model_parallel_size)
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_INTRA_LAYER_MODEL_PARALLEL_GROUP = group
_TENSOR_MODEL_PARALLEL_GROUP = group
# Build the inter-layer model-parallel groups and embedding groups
# (first and last rank in each inter-layer model-parallel group).
global _INTER_LAYER_MODEL_PARALLEL_GROUP
assert _INTER_LAYER_MODEL_PARALLEL_GROUP is None, \
'inter-layer model parallel group is already initialized'
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_inter_layer_model_parallel_groups):
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_inter_layer_model_parallel_groups)
num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_INTER_LAYER_MODEL_PARALLEL_GROUP = group
_PIPELINE_MODEL_PARALLEL_GROUP = group
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
......@@ -155,8 +155,8 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _INTRA_LAYER_MODEL_PARALLEL_GROUP is None or \
_INTER_LAYER_MODEL_PARALLEL_GROUP is None or \
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
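The 16-GPU layout from the initialization docstring above can be reproduced with plain Python, mirroring the three group-building loops in the preceding hunks (a sketch only; the real code additionally wraps each ranks list in torch.distributed.new_group and caches the group the caller belongs to):

    world_size, t, p = 16, 2, 4                     # tensor size 2, pipeline size 4
    d = world_size // (t * p)                       # data-parallel size = 2

    tensor_groups = [list(range(i * t, (i + 1) * t)) for i in range(world_size // t)]
    pipeline_groups = [list(range(i, world_size, world_size // p))
                       for i in range(world_size // p)]
    data_groups = [list(range(i * (world_size // p) + j, (i + 1) * (world_size // p), t))
                   for i in range(p) for j in range(t)]

    # tensor_groups   -> [[0, 1], [2, 3], ..., [14, 15]]                            (8 groups)
    # pipeline_groups -> [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
    # data_groups     -> [[0, 2], [1, 3], [4, 6], [5, 7], ..., [12, 14], [13, 15]]  (8 groups)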
......@@ -169,18 +169,18 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def get_intra_layer_model_parallel_group():
"""Get the intra-layer model parallel group the caller rank belongs to."""
assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is not None, \
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _INTRA_LAYER_MODEL_PARALLEL_GROUP
return _TENSOR_MODEL_PARALLEL_GROUP
def get_inter_layer_model_parallel_group():
"""Get the inter-layer model parallel group the caller rank belongs to."""
assert _INTER_LAYER_MODEL_PARALLEL_GROUP is not None, \
'inter_layer_model parallel group is not initialized'
return _INTER_LAYER_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
'pipeline_model parallel group is not initialized'
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
......@@ -197,87 +197,87 @@ def get_embedding_group():
return _EMBEDDING_GROUP
def set_intra_layer_model_parallel_world_size(world_size):
"""Set the intra-layer model parallel size"""
global _MPU_INTRA_LAYER_WORLD_SIZE
_MPU_INTRA_LAYER_WORLD_SIZE = world_size
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_inter_layer_model_parallel_world_size(world_size):
"""Set the inter-layer model parallel size"""
global _MPU_INTER_LAYER_WORLD_SIZE
_MPU_INTER_LAYER_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_intra_layer_model_parallel_world_size():
"""Return world size for the intra-layer model parallel group."""
global _MPU_INTRA_LAYER_WORLD_SIZE
if _MPU_INTRA_LAYER_WORLD_SIZE is not None:
return _MPU_INTRA_LAYER_WORLD_SIZE
return torch.distributed.get_world_size(group=get_intra_layer_model_parallel_group())
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_inter_layer_model_parallel_world_size():
"""Return world size for the inter-layer model parallel group."""
global _MPU_INTER_LAYER_WORLD_SIZE
if _MPU_INTER_LAYER_WORLD_SIZE is not None:
return _MPU_INTER_LAYER_WORLD_SIZE
return torch.distributed.get_world_size(group=get_inter_layer_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_intra_layer_model_parallel_rank(rank):
"""Set intra-layer model parallel rank."""
global _MPU_INTRA_LAYER_RANK
_MPU_INTRA_LAYER_RANK = rank
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_inter_layer_model_parallel_rank(rank):
"""Set inter-layer model parallel rank."""
global _MPU_INTER_LAYER_RANK
_MPU_INTER_LAYER_RANK = rank
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_intra_layer_model_parallel_rank():
"""Return my rank for the intra-layer model parallel group."""
global _MPU_INTRA_LAYER_RANK
if _MPU_INTRA_LAYER_RANK is not None:
return _MPU_INTRA_LAYER_RANK
return torch.distributed.get_rank(group=get_intra_layer_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_inter_layer_model_parallel_rank():
"""Return my rank for the inter-layer model parallel group."""
global _MPU_INTER_LAYER_RANK
if _MPU_INTER_LAYER_RANK is not None:
return _MPU_INTER_LAYER_RANK
return torch.distributed.get_rank(group=get_inter_layer_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_inter_layer_first_stage():
"""Return True if in the first inter-layer model-parallel stage, False otherwise."""
return get_inter_layer_model_parallel_rank() == 0
def is_pipeline_first_stage():
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == 0
def is_inter_layer_last_stage():
"""Return True if in the last inter-layer model-parallel stage, False otherwise."""
return get_inter_layer_model_parallel_rank() == (
get_inter_layer_model_parallel_world_size() - 1)
def is_pipeline_last_stage():
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def get_intra_layer_model_parallel_src_rank():
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
in the intra-layer model parallel group."""
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_intra_layer_model_parallel_world_size()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_inter_layer_model_parallel_src_rank():
def get_pipeline_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
in the inter-layer model parallel group."""
in the pipeline model parallel group."""
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
local_world_size = get_inter_layer_model_parallel_world_size()
local_world_size = get_pipeline_model_parallel_world_size()
return global_rank % (global_world_size // local_world_size)
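Worked numbers for the two source-rank formulas above, using the same 16-GPU layout as the initialization docstring (tensor size 2, pipeline size 4) and picking global rank 9 as an arbitrary example:

    world_size, t, p = 16, 2, 4
    g = 9                                           # caller's global rank
    tensor_src = (g // t) * t                       # 8: first rank of its tensor group [8, 9]
    pipeline_src = g % (world_size // p)            # 1: rank 9's counterpart in the first stage
                                                    #    of its pipeline group [1, 5, 9, 13]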
......@@ -293,9 +293,9 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
global _INTRA_LAYER_MODEL_PARALLEL_GROUP
_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
global _INTER_LAYER_MODEL_PARALLEL_GROUP
_INTER_LAYER_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
......@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .mappings import copy_to_intra_layer_model_parallel_region
from .mappings import gather_from_intra_layer_model_parallel_region
from .mappings import reduce_from_intra_layer_model_parallel_region
from .mappings import scatter_to_intra_layer_model_parallel_region
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.intra_layer_model_parallel = True
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
weight.intra_layer_model_parallel = True
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
......@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.intra_layer_model_parallel_size = get_intra_layer_model_parallel_world_size()
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_intra_layer_model_parallel_rank(),
self.intra_layer_model_parallel_size)
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
......@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim=0, stride=1)
def forward(self, input_):
if self.intra_layer_model_parallel_size > 1:
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
......@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
if self.intra_layer_model_parallel_size > 1:
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_intra_layer_model_parallel_region(output_parallel)
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
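The masking in the forward pass above is easiest to see with concrete numbers: each rank looks up only the token ids it owns, zeroes the rows it does not, and the all-reduce stitches the full embedding back together. A NumPy sketch with two simulated ranks and a 4-row vocabulary split two rows per rank:

    import numpy as np

    weights = [np.arange(8, dtype=float).reshape(2, 4),          # rank 0 owns vocab rows 0-1
               np.arange(8, dtype=float).reshape(2, 4) + 100]    # rank 1 owns vocab rows 2-3
    tokens = np.array([0, 3, 2])

    partials = []
    for rank, w in enumerate(weights):
        lo, hi = 2 * rank, 2 * (rank + 1)
        mask = (tokens < lo) | (tokens >= hi)
        local_ids = tokens - lo
        local_ids[mask] = 0                      # clamp ids this rank does not own
        out = w[local_ids]
        out[mask, :] = 0.0                       # zero the rows it does not own
        partials.append(out)

    output = partials[0] + partials[1]           # all_reduce(SUM) across the tensor-parallel group
    assert np.allclose(output, np.vstack(weights)[tokens])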
......@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.intra_layer_model_parallel = True
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero.
......@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_intra_layer_model_parallel_region(input_)
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_intra_layer_model_parallel_region(output_parallel)
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_intra_layer_model_parallel_region(input_)
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_intra_layer_model_parallel_region(output_parallel)
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
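For reference, the linear-algebra identity the two layer classes above rely on, checked with NumPy: splitting the weight along its output dimension yields partial results that are concatenated (ColumnParallelLinear followed by the optional gather), while splitting along the input dimension yields partial sums that are all-reduced (RowParallelLinear followed by the reduce):

    import numpy as np

    X = np.random.randn(5, 4)        # [batch, input]
    A = np.random.randn(6, 4)        # full weight; F.linear(X, A) computes X @ A.T

    # Column parallel: each rank holds a slice of the output rows of A.
    col = np.concatenate([X @ A[:3].T, X @ A[3:].T], axis=1)
    # Row parallel: each rank holds a slice of the input columns of A (and of X).
    row = X[:, :2] @ A[:, :2].T + X[:, 2:] @ A[:, 2:].T

    assert np.allclose(col, X @ A.T) and np.allclose(row, X @ A.T)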
......@@ -15,7 +15,7 @@
import torch
from .initialize import get_intra_layer_model_parallel_group, get_intra_layer_model_parallel_world_size, get_intra_layer_model_parallel_rank
from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim
......@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_intra_layer_model_parallel_world_size()==1:
if get_tensor_model_parallel_world_size()==1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_intra_layer_model_parallel_group())
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
......@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
......@@ -45,7 +45,7 @@ def _split(input_):
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_intra_layer_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
......@@ -54,18 +54,18 @@ def _split(input_):
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_intra_layer_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_intra_layer_model_parallel_group())
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
......@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------
def copy_to_intra_layer_model_parallel_region(input_):
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_intra_layer_model_parallel_region(input_):
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_intra_layer_model_parallel_region(input_):
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_intra_layer_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
......@@ -28,13 +28,13 @@ from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'intra-layer-model-parallel-rng'
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'tensor-model-parallel-rng'
# Whether to apply model parallelism to checkpointed hidden states.
......@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_intra_layer_model_parallel_world_size()
start_index = partition_size * get_intra_layer_model_parallel_rank()
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_intra_layer_model_parallel_world_size()
world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
......@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor,
group=get_intra_layer_model_parallel_group())
group=get_tensor_model_parallel_group())
return gathered
......@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
return _CUDA_RNG_STATE_TRACKER
def intra_layer_model_parallel_cuda_manual_seed(seed):
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
......@@ -215,15 +215,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-intra-layer-model-parallel regions.
intra-layer-model-parallel state: This state is different among a set of model
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
intra_layer_model_parallel_seed = offset + get_intra_layer_model_parallel_rank()
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
......@@ -231,15 +231,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_intra_layer_model_parallel_rank(),
get_data_parallel_rank(), intra_layer_model_parallel_seed,
torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
intra_layer_model_parallel_seed)
tensor_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
......
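The seeding scheme described in the docstring above, spelled out for a hypothetical rank (tensor-parallel rank 1), together with the usual consumption pattern via the tracker's fork() context manager:

    seed = 1234                                  # user-supplied seed
    tensor_model_parallel_rank = 1               # hypothetical rank within its tensor-parallel group

    data_parallel_seed = seed                                              # 1234: shared in the group
    tensor_model_parallel_seed = seed + 2718 + tensor_model_parallel_rank  # 3953: unique per rank

    # Model code draws tensor-parallel randomness (e.g. dropout on partitioned activations)
    # under the tracked state, so the ranks of one tensor-parallel group get independent
    # masks while the default state keeps everything else in lock-step:
    #
    #     with get_cuda_rng_tracker().fork():
    #         out = torch.nn.functional.dropout(x, p=0.1)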
......@@ -36,7 +36,7 @@ def set_random_seed(seed):
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
mpu.intra_layer_model_parallel_cuda_manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
......
......@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
logits_parallel = mpu.scatter_to_intra_layer_model_parallel_region(logits)
logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
......@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return loss, identity.weight.grad
def test_cross_entropy(intra_layer_model_parallel_size):
def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * intra_layer_model_parallel_size
vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
......@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size):
assert error < 1.0e-6
# Reset groups
mpu.destroy_intra_layer_model_parallel()
mpu.destroy_tensor_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_cross_entropy(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -24,15 +24,15 @@ import sys
sys.path.append("../..")
def test_broadcast_data(intra_layer_model_parallel_size):
def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1],
......@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if mpu.get_intra_layer_model_parallel_rank() != 0:
if mpu.get_tensor_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
......@@ -69,7 +69,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
mpu.destroy_intra_layer_model_parallel()
mpu.destroy_tensor_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test broadcast data')
test_broadcast_data(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_broadcast_data(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -21,15 +21,15 @@ import sys
sys.path.append("../..")
def test_initialize_model_parallel(intra_layer_model_parallel_size):
def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
intra_layer_model_parallel_size))
intra_layer_model_parallel_size_ = min(intra_layer_model_parallel_size,
tensor_model_parallel_size))
tensor_model_parallel_size_ = min(tensor_model_parallel_size,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(intra_layer_model_parallel_size_)
mpu.initialize_model_parallel(tensor_model_parallel_size_)
assert mpu.model_parallel_is_initialized()
# Checks.
......@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = intra_layer_model_parallel_size_
rank = torch.distributed.get_rank() % intra_layer_model_parallel_size_
assert world_size == mpu.get_intra_layer_model_parallel_world_size()
assert rank == mpu.get_intra_layer_model_parallel_rank()
check(mpu.get_intra_layer_model_parallel_group(), world_size, rank)
world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == mpu.get_tensor_model_parallel_world_size()
assert rank == mpu.get_tensor_model_parallel_rank()
check(mpu.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // intra_layer_model_parallel_size_
rank = torch.distributed.get_rank() // intra_layer_model_parallel_size
world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank)
......@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
print('>> passed the test :-)')
def test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size_):
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_intra_layer_model_parallel_src_rank with size {} ...'.format(
intra_layer_model_parallel_size_))
intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_,
print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
tensor_model_parallel_size_))
tensor_model_parallel_size = min(tensor_model_parallel_size_,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
assert mpu.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - mpu.get_intra_layer_model_parallel_rank()
assert mpu.get_intra_layer_model_parallel_src_rank() == src_rank
src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
assert mpu.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups
mpu.destroy_model_parallel()
......@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(intra_layer_model_parallel_size)
test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank')
test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -26,14 +26,14 @@ import sys
sys.path.append("../..")
def test_parallel_embedding(intra_layer_model_parallel_size):
def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 17
seq_length = 23
......@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // intra_layer_model_parallel_size,
1)[mpu.get_intra_layer_model_parallel_rank()]
hidden_size // tensor_model_parallel_size,
1)[mpu.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // intra_layer_model_parallel_size,
0)[mpu.get_intra_layer_model_parallel_rank()]
vocab_size // tensor_model_parallel_size,
0)[mpu.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
......@@ -104,19 +104,19 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
print('>> passed the test :-)')
def test_initialize_affine_weight(intra_layer_model_parallel_size):
def test_initialize_affine_weight(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
# ---------------
# Column parallel
......@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
......@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
......@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight
def test_column_parallel_linear(intra_layer_model_parallel_size):
def test_column_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
......@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
print(' >> passed the test :-)')
def test_row_parallel_linear(intra_layer_model_parallel_size):
def test_row_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
'size: {}'.format(tensor_model_parallel_size))
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
# Network
......@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight
def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_partition,
def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p
# Backward
loss.backward()
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, intra_layer_model_parallel_size, loss, \
return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(intra_layer_model_parallel_size):
def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(intra_layer_model_parallel_size))
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
......@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
batch_size = 5
sequence_length = 13
rank_1, hideen_size_1, intra_layer_model_parallel_size_1, loss_1, \
rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, intra_layer_model_parallel_size, loss, \
rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
intra_layer_model_parallel_size, num_att_heads_per_partition,
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size
......@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // intra_layer_model_parallel_size, 0)[rank::intra_layer_model_parallel_size]
hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
......@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
print(' >> passed the test :-)')
def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_partition,
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part
# Backward
loss.backward()
rank = mpu.get_intra_layer_model_parallel_rank()
rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, intra_layer_model_parallel_size, loss, \
return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(intra_layer_model_parallel_size):
def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(intra_layer_model_parallel_size))
'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, intra_layer_model_parallel_size_1, loss_1, \
rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, intra_layer_model_parallel_size, loss, \
rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
intra_layer_model_parallel_size, num_att_heads_per_partition,
tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
......@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight')
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_initialize_affine_weight(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test parallel embedding')
test_parallel_embedding(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_parallel_embedding(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test column-parallel linear')
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_column_parallel_linear(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test row-parallel linear')
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_row_parallel_linear(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test parallel self-attention')
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_parallel_self_attention(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_parallel_self_attention(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
print_separator('test parallel transformer')
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_parallel_transformer_layer(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
test_parallel_transformer_layer(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
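All of the drivers above repeat the same sweep: start at a tensor-model-parallel size of 1 and double it until it exceeds the world size. A standalone sketch of that pattern (the helper name is hypothetical; the driver simply inlines the loop for each test):
def sweep_tensor_model_parallel_sizes(test_fn, world_size):
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_fn(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

# Example: with world_size == 8 the sweep covers sizes 1, 2, 4 and 8.
sweep_tensor_model_parallel_sizes(print, 8)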
......@@ -21,14 +21,14 @@ import sys
sys.path.append("../..")
def test_set_cuda_rng_state(intra_layer_model_parallel_size):
def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
size = 123
seed = 1234
......@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size):
print('>> passed the test :-)')
def test_cuda_rng_tracker(intra_layer_model_parallel_size):
def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
......@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size):
print('>> passed the test :-)')
def test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size):
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'.
format(intra_layer_model_parallel_size))
format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.initialize_model_parallel(tensor_model_parallel_size)
tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
mpu.intra_layer_model_parallel_cuda_manual_seed(12345)
mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_intra_layer_model_parallel_rank())
mpu.get_tensor_model_parallel_rank())
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
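The assertions above pin down the seeding convention: the default CUDA RNG keeps the base seed on every rank, while the tracked model-parallel RNG is offset by a constant plus the tensor-model-parallel rank, so partitions of the same layer draw different randomness. A hedged restatement (the constant 2718 mirrors the test; the actual offset logic lives inside mpu.model_parallel_cuda_manual_seed):
def expected_seeds(base_seed, tensor_model_parallel_rank):
    default_seed = base_seed  # shared by all ranks
    model_parallel_seed = base_seed + 2718 + tensor_model_parallel_rank
    return default_seed, model_parallel_seed

assert expected_seeds(12345, 0) == (12345, 12345 + 2718)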
......@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_set_cuda_rng_state(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_cuda_rng_tracker(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
tensor_model_parallel_size = 1
while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
tensor_model_parallel_size *= 2
......@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
......@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
......@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
......@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
......@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
raw_text = None
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
context_count += 1
......@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
......@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
......@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_intra_layer_model_parallel_rank() == 0 and \
if mpu.get_tensor_model_parallel_rank() == 0 and \
counter % print_frequency == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
......@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
......@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
context_count += 1
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
......@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor,
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
......
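The text-generation hunks above all rely on the same collective: rank 0 of each tensor-model-parallel group prepares a CUDA tensor (a prompt, a context length, a termination flag) and every rank in the group then receives the same values from the group's source rank. A minimal sketch of that pattern (the wrapper name is illustrative, not the repository's API):
import torch

def broadcast_in_tensor_parallel_group(tensor, mpu):
    # src is the global rank of the group's first member; the group argument
    # restricts the collective to this tensor-model-parallel group.
    torch.distributed.broadcast(tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    return tensor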
......@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.intra_layer_model_parallel_size
args.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
......
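The padding rule above grows the vocabulary to the next multiple of make_vocab_size_divisible_by * tensor_model_parallel_size, so every tensor-parallel partition of the embedding table ends up the same size. A standalone restatement:
def padded_vocab_size(orig_vocab_size, divisible_by, tensor_model_parallel_size):
    multiple = divisible_by * tensor_model_parallel_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

# e.g. GPT-2's 50257-token vocabulary with divisible_by=128 and a
# tensor-model-parallel size of 8 is padded up to 51200.
assert padded_vocab_size(50257, 128, 8) == 51200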
......@@ -124,10 +124,10 @@ def get_model(model_provider_func):
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (intra-layer, inter-layer) '
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_intra_layer_model_parallel_rank(),
mpu.get_inter_layer_model_parallel_rank(),
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
......@@ -166,8 +166,8 @@ def get_optimizer(model):
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'intra_layer_model_parallel'):
param.intra_layer_model_parallel = False
if not hasattr(param, 'tensor_model_parallel'):
param.tensor_model_parallel = False
# Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
......@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_inter_layer_model_parallel_group())
group=mpu.get_pipeline_model_parallel_group())
return tensor_recv_prev, tensor_recv_next
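communicate() above is the point-to-point exchange between adjacent pipeline stages: a stage may send activations forward and/or gradients backward, and receive the matching tensors, all within the pipeline-model-parallel group. One way to express the same exchange with stock torch.distributed primitives is sketched below; this is an assumption about an equivalent formulation, not the repository's implementation:
import torch

def exchange_with_neighbors(tensor_send_next, tensor_send_prev,
                            tensor_recv_prev, tensor_recv_next,
                            prev_rank, next_rank, group):
    ops = []
    if tensor_send_prev is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.isend,
                                           tensor_send_prev, prev_rank, group))
    if tensor_recv_prev is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.irecv,
                                           tensor_recv_prev, prev_rank, group))
    if tensor_send_next is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.isend,
                                           tensor_send_next, next_rank, group))
    if tensor_recv_next is not None:
        ops.append(torch.distributed.P2POp(torch.distributed.irecv,
                                           tensor_recv_next, next_rank, group))
    if ops:
        for req in torch.distributed.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next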
......@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator,
optimizer.zero_grad()
# Compute number of microbatches in a minibatch.
num_microbatches_to_pipeline = args.inter_layer_model_parallel_size \
num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
if args.use_pipelining else 1
input_tensors = []
......@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator,
# Run forward pass for all microbatches in minibatch.
for i in range(num_microbatches_to_pipeline):
if not mpu.is_inter_layer_first_stage():
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward').stop()
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss
losses_reduced.append(loss_reduced)
......@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator,
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
output_grad_tensor = None
else:
_, output_grad_tensor = communicate(
......@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator,
backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
timers('backward').stop()
if not mpu.is_inter_layer_first_stage():
if not mpu.is_pipeline_first_stage():
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
......@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator,
timers('backward-master-grad').stop()
# All-reduce across first and last stages.
if (mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage()) and \
args.inter_layer_model_parallel_size > 1:
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
args.pipeline_model_parallel_size > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module
......@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
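Taken together, the train_step hunks implement a GPipe-style schedule: every stage first runs the forward pass for all microbatches in the minibatch (receiving activations from the previous stage and sending them to the next), then runs the backward passes in the same order (receiving output gradients from the next stage and sending input gradients to the previous). A hedged, framework-free sketch of that control flow, with the recv/send callables standing in for communicate() and forward/backward for forward_step_func and backward_step:
def pipeline_train_step(num_microbatches, is_first_stage, is_last_stage,
                        forward, backward,
                        recv_prev, send_next, recv_next, send_prev):
    input_tensors, output_tensors = [], []
    # Forward passes for all microbatches in the minibatch.
    for _ in range(num_microbatches):
        input_tensor = None if is_first_stage else recv_prev()
        output_tensor = forward(input_tensor)
        if not is_last_stage:
            send_next(output_tensor)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
    # Backward passes, in the same order.
    for _ in range(num_microbatches):
        input_tensor = input_tensors.pop(0)
        output_tensor = output_tensors.pop(0)
        output_grad = None if is_last_stage else recv_next()
        input_grad = backward(input_tensor, output_tensor, output_grad)
        if not is_first_stage:
            send_prev(input_grad)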
......@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
if not mpu.is_inter_layer_first_stage():
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
......@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
_, loss_dict = output_tensor
# Reduce across processes.
for key in loss_dict:
......@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators(
print_rank_0('> building train, validation, and test datasets ...')
# Data loader only on rank 0 of each model parallel group.
if mpu.get_intra_layer_model_parallel_rank() == 0:
if mpu.get_tensor_model_parallel_rank() == 0:
# Rank, size, and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size()
global_batch_size = args.batch_size * data_parallel_size
......@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators(
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
......
......@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, intra-layer-model-parallel, min, max, norm\n'
string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
......@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.intra_layer_model_parallel))
iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
......
......@@ -34,12 +34,12 @@ def model_provider():
print_rank_0('building BERT model ...')
args = get_args()
if args.inter_layer_model_parallel_size > 1:
if args.pipeline_model_parallel_size > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_inter_layer_last_stage():
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
......@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers('batch generator').stop()
# Forward pass through the model.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_inter_layer_last_stage():
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
......
......@@ -33,11 +33,11 @@ def model_provider():
print_rank_0('building GPT2 model ...')
args = get_args()
if args.inter_layer_model_parallel_size > 1:
if args.pipeline_model_parallel_size > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_inter_layer_last_stage():
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
......@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers('batch generator').stop()
# Forward pass through the model.
if mpu.is_inter_layer_first_stage():
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_inter_layer_last_stage():
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_inter_layer_last_stage():
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......
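Both forward_step functions above follow the same stage-dependent calling convention: the first pipeline stage consumes the raw batch, the last stage additionally receives the labels and produces the loss, and intermediate stages consume the activation tensor handed over from the previous stage. A hedged sketch of that dispatch (the model signatures are illustrative placeholders, not the exact repository APIs):
def stage_forward(model, batch, labels, input_tensor,
                  is_first_stage, is_last_stage):
    if is_first_stage:
        assert input_tensor is None
        if is_last_stage:                       # single-stage pipeline
            return model(batch, labels=labels)  # returns a loss
        return model(batch)                     # returns activations
    if is_last_stage:
        assert input_tensor is not None
        return model(input_tensor, labels=labels)
    assert input_tensor is not None
    return model(input_tensor)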