Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
@@ -28,15 +28,15 @@ from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_group
+from .initialize import get_tensor_model_parallel_group
-from .initialize import get_inter_layer_model_parallel_group
+from .initialize import get_pipeline_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
-from .initialize import get_inter_layer_model_parallel_rank, set_inter_layer_model_parallel_rank
+from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
-from .initialize import is_inter_layer_first_stage, is_inter_layer_last_stage
+from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
-from .initialize import get_intra_layer_model_parallel_src_rank
+from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_inter_layer_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_src_rank
-from .initialize import get_intra_layer_model_parallel_world_size, set_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
-from .initialize import get_inter_layer_model_parallel_world_size, set_inter_layer_model_parallel_world_size
+from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
@@ -45,15 +45,15 @@ from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
-from .mappings import copy_to_intra_layer_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_intra_layer_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_intra_layer_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_intra_layer_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
-from .random import intra_layer_model_parallel_cuda_manual_seed
+from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .utils import divide
...
@@ -16,9 +16,9 @@
import torch
-from .initialize import get_intra_layer_model_parallel_group
+from .initialize import get_tensor_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_world_size
from .utils import VocabUtility
@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
-rank = get_intra_layer_model_parallel_rank()
+rank = get_tensor_model_parallel_rank()
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size)
@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
...
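For reference, the arithmetic in _VocabParallelCrossEntropy above (shift by the max, gather the predicted logit, log-sum-exp, loss = log(sum(exp(logits))) - predicted logit) can be sketched on a single process with a full, unpartitioned vocab; the all-reduces in the real kernel only complete these same max/sum operations across tensor-parallel ranks. A minimal sketch, not the distributed implementation:

import torch

def stable_cross_entropy(logits, target):
    # Same math as the kernel, on one full-vocab tensor (no partitioning, no all-reduce).
    logits = logits - logits.max(dim=-1, keepdim=True)[0]            # subtract the maximum value
    predicted_logits = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    sum_exp_logits = logits.exp().sum(dim=-1)                        # sum of exponentials over the vocab
    return torch.log(sum_exp_logits) - predicted_logits              # per-token loss

logits = torch.randn(4, 7, 32)                                       # [batch, seq, vocab]
target = torch.randint(0, 32, (4, 7))
print(stable_cross_entropy(logits, target).shape)                    # torch.Size([4, 7])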
@@ -15,9 +15,9 @@
import torch
-from .initialize import get_intra_layer_model_parallel_group
+from .initialize import get_tensor_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_src_rank
+from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
-if get_intra_layer_model_parallel_rank() == 0:
+if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
-torch.distributed.broadcast(sizes_cuda, get_intra_layer_model_parallel_src_rank(),
+torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data)
# Pack on rank zero.
-if get_intra_layer_model_parallel_rank() == 0:
+if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
@@ -101,8 +101,8 @@ def broadcast_data(keys, data, datatype):
dtype=datatype)
# Broadcast
-torch.distributed.broadcast(flatten_data, get_intra_layer_model_parallel_src_rank(),
+torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
# Unpack
output = {}
...
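The size packing that _build_key_size_numel_dictionaries performs on the source rank can be sketched without torch.distributed: each key's shape is laid into a fixed-width slot of _MAX_DATA_DIM entries so every shape travels in one flat broadcast. A rough sketch under that assumption (the key names and tensors here are made up):

import torch

MAX_DATA_DIM = 4  # mirrors _MAX_DATA_DIM above

def pack_sizes(keys, data):
    # One zero-padded, fixed-width slot per key, so all shapes fit one flat list.
    sizes = [0] * (MAX_DATA_DIM * len(keys))
    offset = 0
    for key in keys:
        assert data[key].dim() < MAX_DATA_DIM, 'you should increase MAX_DATA_DIM'
        for i, s in enumerate(data[key].size()):
            sizes[i + offset] = s
        offset += MAX_DATA_DIM
    return sizes

data = {'text': torch.zeros(7, 11, dtype=torch.int64),
        'mask': torch.zeros(8, 2, 1, dtype=torch.int64)}
print(pack_sizes(['text', 'mask'], data))   # [7, 11, 0, 0, 8, 2, 1, 0]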
@@ -28,9 +28,9 @@ try:
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
-from .initialize import is_inter_layer_first_stage
+from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
def l2_grad_clipper(parameters, max_norm):
@@ -44,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
-mp_rank_is_zero = (get_intra_layer_model_parallel_rank() == 0)
+mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
-lambda p: p.intra_layer_model_parallel or mp_rank_is_zero, parameters_with_grads))
+lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
@@ -101,7 +101,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
-is_inter_layer_first_stage():
+is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
@@ -123,7 +123,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
else:
total_norm = 0
for p in parameters:
-if p.intra_layer_model_parallel or (get_intra_layer_model_parallel_rank() == 0):
+if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model-parallel GPUs.
...
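The filter above is what keeps replicated parameters from being double counted when the squared norms are later summed across all model-parallel GPUs: a parameter contributes locally only if it is actually partitioned (its tensor_model_parallel attribute is True) or if the caller is tensor-parallel rank 0. A small illustration of that predicate with hypothetical stand-in parameter objects:

# Hypothetical stand-ins; the real objects are torch.nn.Parameter instances whose
# tensor_model_parallel attribute is set by the layers in the weight-init helpers.
class P:
    def __init__(self, name, tensor_model_parallel):
        self.name = name
        self.tensor_model_parallel = tensor_model_parallel

params = [P('attn.qkv', True),      # partitioned: every rank holds a different shard
          P('ln.weight', False)]    # replicated: identical copy on every rank

def counted_locally(p, tensor_model_parallel_rank):
    # Partitioned params count on every rank; replicated ones only on rank 0,
    # so the later sum over model-parallel GPUs counts each weight exactly once.
    return p.tensor_model_parallel or tensor_model_parallel_rank == 0

for rank in (0, 1):
    print(rank, [p.name for p in params if counted_locally(p, rank)])
# 0 ['attn.qkv', 'ln.weight']
# 1 ['attn.qkv']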
@@ -22,10 +22,10 @@ from .utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
-_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
+_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
-_INTER_LAYER_MODEL_PARALLEL_GROUP = None
+_PIPELINE_MODEL_PARALLEL_GROUP = None
-# Model parallel group (both intra- and inter-layer) that the current rank belongs to.
+# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
@@ -33,10 +33,10 @@ _EMBEDDING_GROUP = None
_DATA_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
-_MPU_INTRA_LAYER_WORLD_SIZE = None
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
-_MPU_INTER_LAYER_WORLD_SIZE = None
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
-_MPU_INTRA_LAYER_RANK = None
+_MPU_TENSOR_MODEL_PARALLEL_RANK = None
-_MPU_INTER_LAYER_RANK = None
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
def is_unitialized():
@@ -44,25 +44,25 @@ def is_unitialized():
return _DATA_PARALLEL_GROUP is None
-def initialize_model_parallel(intra_layer_model_parallel_size_=1,
+def initialize_model_parallel(tensor_model_parallel_size_=1,
-inter_layer_model_parallel_size_=1):
+pipeline_model_parallel_size_=1):
"""
Initialize model data parallel groups.
Arguments:
-intra_layer_model_parallel_size: number of GPUs used to parallelize model intra-layer.
+tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
-inter_layer_model_parallel_size: number of GPUs used to parallelize model inter-layer.
+pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
-use 2 GPUs to parallelize the model intra-layer, and 4 GPUs to parallelize
+use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
-the model inter-layer. The present function will
+the model pipeline. The present function will
-create 8 intra-layer model-parallel groups, 4 inter-layer model-parallel groups
+create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
-8 intra-layer model-parallel groups:
+8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
-4 inter-layer model-parallel groups:
+4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
@@ -70,22 +70,22 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
-print('> initializing intra-layer model parallel with size {}'.format(
+print('> initializing tensor model parallel with size {}'.format(
-intra_layer_model_parallel_size_))
+tensor_model_parallel_size_))
-print('> initializing inter-layer model parallel with size {}'.format(
+print('> initializing pipeline model parallel with size {}'.format(
-inter_layer_model_parallel_size_))
+pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
-intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_, world_size)
+tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
-inter_layer_model_parallel_size = min(inter_layer_model_parallel_size_, world_size)
+pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size,
-intra_layer_model_parallel_size * inter_layer_model_parallel_size)
+tensor_model_parallel_size * pipeline_model_parallel_size)
-data_parallel_size = world_size // (intra_layer_model_parallel_size *
+data_parallel_size = world_size // (tensor_model_parallel_size *
-inter_layer_model_parallel_size)
+pipeline_model_parallel_size)
-num_intra_layer_model_parallel_groups = world_size // intra_layer_model_parallel_size
+num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
-num_inter_layer_model_parallel_groups = world_size // inter_layer_model_parallel_size
+num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank()
@@ -95,12 +95,12 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
all_data_parallel_group_ranks = []
-for i in range(inter_layer_model_parallel_size):
+for i in range(pipeline_model_parallel_size):
-start_rank = i * num_inter_layer_model_parallel_groups
+start_rank = i * num_pipeline_model_parallel_groups
-end_rank = (i + 1) * num_inter_layer_model_parallel_groups
+end_rank = (i + 1) * num_pipeline_model_parallel_groups
-for j in range(intra_layer_model_parallel_size):
+for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank,
-intra_layer_model_parallel_size)
+tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
@@ -117,31 +117,31 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
-# Build the intra-layer model-parallel groups.
+# Build the tensor model-parallel groups.
-global _INTRA_LAYER_MODEL_PARALLEL_GROUP
+global _TENSOR_MODEL_PARALLEL_GROUP
-assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is None, \
+assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
-'intra-layer model parallel group is already initialized'
+'tensor model parallel group is already initialized'
-for i in range(num_intra_layer_model_parallel_groups):
+for i in range(num_tensor_model_parallel_groups):
-ranks = range(i * intra_layer_model_parallel_size,
+ranks = range(i * tensor_model_parallel_size,
-(i + 1) * intra_layer_model_parallel_size)
+(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
-_INTRA_LAYER_MODEL_PARALLEL_GROUP = group
+_TENSOR_MODEL_PARALLEL_GROUP = group
-# Build the inter-layer model-parallel groups and embedding groups
+# Build the pipeline model-parallel groups and embedding groups
-# (first and last rank in each inter-layer model-parallel group).
+# (first and last rank in each pipeline model-parallel group).
-global _INTER_LAYER_MODEL_PARALLEL_GROUP
+global _PIPELINE_MODEL_PARALLEL_GROUP
-assert _INTER_LAYER_MODEL_PARALLEL_GROUP is None, \
+assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
-'inter-layer model parallel group is already initialized'
+'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
-for i in range(num_inter_layer_model_parallel_groups):
+for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
-num_inter_layer_model_parallel_groups)
+num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
-_INTER_LAYER_MODEL_PARALLEL_GROUP = group
+_PIPELINE_MODEL_PARALLEL_GROUP = group
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
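The three group-building loops above are pure index arithmetic, so the resulting rank lists can be reproduced offline. A small sketch for the 16-GPU example in the docstring (tensor size 2, pipeline size 4), with no torch.distributed calls:

world_size, tensor_size, pipeline_size = 16, 2, 4
num_tensor_groups = world_size // tensor_size        # 8
num_pipeline_groups = world_size // pipeline_size    # 4

# Data-parallel groups: same loop structure as the code above.
data_groups = [list(range(i * num_pipeline_groups + j,
                          (i + 1) * num_pipeline_groups,
                          tensor_size))
               for i in range(pipeline_size) for j in range(tensor_size)]
# Tensor model-parallel groups: consecutive blocks of tensor_size ranks.
tensor_groups = [list(range(i * tensor_size, (i + 1) * tensor_size))
                 for i in range(num_tensor_groups)]
# Pipeline model-parallel groups: strided by num_pipeline_groups.
pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                   for i in range(num_pipeline_groups)]

print(data_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7], ..., [12, 14], [13, 15]]
print(tensor_groups)    # [[0, 1], [2, 3], ..., [14, 15]]
print(pipeline_groups)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]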
@@ -155,8 +155,8 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
-if _INTRA_LAYER_MODEL_PARALLEL_GROUP is None or \
+if _TENSOR_MODEL_PARALLEL_GROUP is None or \
-_INTER_LAYER_MODEL_PARALLEL_GROUP is None or \
+_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
@@ -169,18 +169,18 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
-def get_intra_layer_model_parallel_group():
+def get_tensor_model_parallel_group():
-"""Get the intra-layer model parallel group the caller rank belongs to."""
+"""Get the tensor model parallel group the caller rank belongs to."""
-assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is not None, \
+assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
-return _INTRA_LAYER_MODEL_PARALLEL_GROUP
+return _TENSOR_MODEL_PARALLEL_GROUP
-def get_inter_layer_model_parallel_group():
+def get_pipeline_model_parallel_group():
-"""Get the inter-layer model parallel group the caller rank belongs to."""
+"""Get the pipeline model parallel group the caller rank belongs to."""
-assert _INTER_LAYER_MODEL_PARALLEL_GROUP is not None, \
+assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
-'inter_layer_model parallel group is not initialized'
+'pipeline_model parallel group is not initialized'
-return _INTER_LAYER_MODEL_PARALLEL_GROUP
+return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
@@ -197,87 +197,87 @@ def get_embedding_group():
return _EMBEDDING_GROUP
-def set_intra_layer_model_parallel_world_size(world_size):
+def set_tensor_model_parallel_world_size(world_size):
-"""Set the intra-layer model parallel size"""
+"""Set the tensor model parallel size"""
-global _MPU_INTRA_LAYER_WORLD_SIZE
+global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
-_MPU_INTRA_LAYER_WORLD_SIZE = world_size
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
-def set_inter_layer_model_parallel_world_size(world_size):
+def set_pipeline_model_parallel_world_size(world_size):
-"""Set the inter-layer model parallel size"""
+"""Set the pipeline model parallel size"""
-global _MPU_INTER_LAYER_WORLD_SIZE
+global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-_MPU_INTER_LAYER_WORLD_SIZE = world_size
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
-def get_intra_layer_model_parallel_world_size():
+def get_tensor_model_parallel_world_size():
-"""Return world size for the intra-layer model parallel group."""
+"""Return world size for the tensor model parallel group."""
-global _MPU_INTRA_LAYER_WORLD_SIZE
+global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
-if _MPU_INTRA_LAYER_WORLD_SIZE is not None:
+if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
-return _MPU_INTRA_LAYER_WORLD_SIZE
+return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
-return torch.distributed.get_world_size(group=get_intra_layer_model_parallel_group())
+return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
-def get_inter_layer_model_parallel_world_size():
+def get_pipeline_model_parallel_world_size():
-"""Return world size for the inter-layer model parallel group."""
+"""Return world size for the pipeline model parallel group."""
-global _MPU_INTER_LAYER_WORLD_SIZE
+global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-if _MPU_INTER_LAYER_WORLD_SIZE is not None:
+if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
-return _MPU_INTER_LAYER_WORLD_SIZE
+return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-return torch.distributed.get_world_size(group=get_inter_layer_model_parallel_group())
+return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
-def set_intra_layer_model_parallel_rank(rank):
+def set_tensor_model_parallel_rank(rank):
-"""Set intra-layer model parallel rank."""
+"""Set tensor model parallel rank."""
-global _MPU_INTRA_LAYER_RANK
+global _MPU_TENSOR_MODEL_PARALLEL_RANK
-_MPU_INTRA_LAYER_RANK = rank
+_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
-def set_inter_layer_model_parallel_rank(rank):
+def set_pipeline_model_parallel_rank(rank):
-"""Set inter-layer model parallel rank."""
+"""Set pipeline model parallel rank."""
-global _MPU_INTER_LAYER_RANK
+global _MPU_PIPELINE_MODEL_PARALLEL_RANK
-_MPU_INTER_LAYER_RANK = rank
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
-def get_intra_layer_model_parallel_rank():
+def get_tensor_model_parallel_rank():
-"""Return my rank for the intra-layer model parallel group."""
+"""Return my rank for the tensor model parallel group."""
-global _MPU_INTRA_LAYER_RANK
+global _MPU_TENSOR_MODEL_PARALLEL_RANK
-if _MPU_INTRA_LAYER_RANK is not None:
+if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
-return _MPU_INTRA_LAYER_RANK
+return _MPU_TENSOR_MODEL_PARALLEL_RANK
-return torch.distributed.get_rank(group=get_intra_layer_model_parallel_group())
+return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
-def get_inter_layer_model_parallel_rank():
+def get_pipeline_model_parallel_rank():
-"""Return my rank for the inter-layer model parallel group."""
+"""Return my rank for the pipeline model parallel group."""
-global _MPU_INTER_LAYER_RANK
+global _MPU_PIPELINE_MODEL_PARALLEL_RANK
-if _MPU_INTER_LAYER_RANK is not None:
+if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
-return _MPU_INTER_LAYER_RANK
+return _MPU_PIPELINE_MODEL_PARALLEL_RANK
-return torch.distributed.get_rank(group=get_inter_layer_model_parallel_group())
+return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
-def is_inter_layer_first_stage():
+def is_pipeline_first_stage():
-"""Return True if in the first inter-layer model-parallel stage, False otherwise."""
+"""Return True if in the first pipeline model-parallel stage, False otherwise."""
-return get_inter_layer_model_parallel_rank() == 0
+return get_pipeline_model_parallel_rank() == 0
-def is_inter_layer_last_stage():
+def is_pipeline_last_stage():
-"""Return True if in the last inter-layer model-parallel stage, False otherwise."""
+"""Return True if in the last pipeline model-parallel stage, False otherwise."""
-return get_inter_layer_model_parallel_rank() == (
+return get_pipeline_model_parallel_rank() == (
-get_inter_layer_model_parallel_world_size() - 1)
+get_pipeline_model_parallel_world_size() - 1)
-def get_intra_layer_model_parallel_src_rank():
+def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
-in the intra-layer model parallel group."""
+in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
-local_world_size = get_intra_layer_model_parallel_world_size()
+local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
-def get_inter_layer_model_parallel_src_rank():
+def get_pipeline_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
-in the inter-layer model parallel group."""
+in the pipeline model parallel group."""
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
-local_world_size = get_inter_layer_model_parallel_world_size()
+local_world_size = get_pipeline_model_parallel_world_size()
return global_rank % (global_world_size // local_world_size)
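With the 16-GPU layout from the docstring (tensor size 2, pipeline size 4), the two source-rank formulas above can be checked by hand; a tiny sketch with illustrative ranks:

# Tensor-parallel source rank: first rank of the caller's tensor group.
# Rank 5 sits in tensor group [4, 5]  ->  (5 // 2) * 2 == 4
tensor_src = lambda global_rank, tensor_world: (global_rank // tensor_world) * tensor_world
assert tensor_src(5, 2) == 4

# Pipeline-parallel source rank: first rank of the caller's pipeline group.
# Rank 9 sits in pipeline group [1, 5, 9, 13]  ->  9 % (16 // 4) == 1
pipeline_src = lambda global_rank, world, pipe_world: global_rank % (world // pipe_world)
assert pipeline_src(9, 16, 4) == 1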
@@ -293,9 +293,9 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
-global _INTRA_LAYER_MODEL_PARALLEL_GROUP
+global _TENSOR_MODEL_PARALLEL_GROUP
-_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
+_TENSOR_MODEL_PARALLEL_GROUP = None
-global _INTER_LAYER_MODEL_PARALLEL_GROUP
+global _PIPELINE_MODEL_PARALLEL_GROUP
-_INTER_LAYER_MODEL_PARALLEL_GROUP = None
+_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_world_size
-from .mappings import copy_to_intra_layer_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_intra_layer_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_intra_layer_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_intra_layer_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
-weight.intra_layer_model_parallel = True
+weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
-weight.intra_layer_model_parallel = True
+weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
-self.intra_layer_model_parallel_size = get_intra_layer_model_parallel_world_size()
+self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
-self.num_embeddings, get_intra_layer_model_parallel_rank(),
+self.num_embeddings, get_tensor_model_parallel_rank(),
-self.intra_layer_model_parallel_size)
+self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim=0, stride=1)
def forward(self, input_):
-if self.intra_layer_model_parallel_size > 1:
+if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
-if self.intra_layer_model_parallel_size > 1:
+if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
-output = reduce_from_intra_layer_model_parallel_region(output_parallel)
+output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
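The mask-then-reduce pattern in this forward can be emulated on one process by splitting an embedding table into two vocab shards: each shard looks up only the ids it owns, zeroes the rest, and the sum of the partial outputs equals a full-table lookup (the all-reduce in the real code performs that sum across ranks). A single-process sketch with made-up sizes:

import torch

vocab, hidden, tp = 8, 4, 2
torch.manual_seed(0)
table = torch.randn(vocab, hidden)         # full embedding table
ids = torch.tensor([[1, 6, 3]])

partial_sum = torch.zeros(1, 3, hidden)
per_rank = vocab // tp
for rank in range(tp):                     # emulate the tp ranks sequentially
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = table[lo:hi]                   # this rank's rows of the table
    mask = (ids < lo) | (ids >= hi)        # ids owned by some other rank
    local_ids = (ids - lo).clamp(min=0)
    local_ids[mask] = 0
    out = torch.nn.functional.embedding(local_ids, shard)
    out[mask, :] = 0.0                     # zero rows for ids this rank does not own
    partial_sum += out                     # stands in for the all-reduce

assert torch.equal(partial_sum, torch.nn.functional.embedding(ids, table))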
@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
-self.bias.intra_layer_model_parallel = True
+self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero.
@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def forward(self, input_):
# Set up backprop all-reduce.
-input_parallel = copy_to_intra_layer_model_parallel_region(input_)
+input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
-output = gather_from_intra_layer_model_parallel_region(output_parallel)
+output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
-input_parallel = scatter_to_intra_layer_model_parallel_region(input_)
+input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
-output_ = reduce_from_intra_layer_model_parallel_region(output_parallel)
+output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
...
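The algebra behind ColumnParallelLinear and RowParallelLinear can be checked in isolation: splitting the weight by output features and concatenating the partial outputs, or splitting by input features and summing the partial outputs, both reproduce the unpartitioned matmul. A single-process sketch emulating two tensor-parallel ranks:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(5, 8)                      # [tokens, in_features]
w = torch.randn(6, 8)                      # [out_features, in_features]; y = x @ w.T

# Column parallel: each rank owns a slice of the output features.
w_col = torch.chunk(w, 2, dim=0)           # two shards of shape [3, 8]
y_col = torch.cat([F.linear(x, shard) for shard in w_col], dim=-1)   # the "gather"

# Row parallel: each rank owns a slice of the input features and of x.
w_row = torch.chunk(w, 2, dim=1)           # two shards of shape [6, 4]
x_row = torch.chunk(x, 2, dim=-1)          # matching input slices (the "scatter")
y_row = sum(F.linear(xi, wi) for xi, wi in zip(x_row, w_row))        # the "all-reduce"

y = F.linear(x, w)
assert torch.allclose(y_col, y, atol=1e-5)
assert torch.allclose(y_row, y, atol=1e-5)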
@@ -15,7 +15,7 @@
import torch
-from .initialize import get_intra_layer_model_parallel_group, get_intra_layer_model_parallel_world_size, get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim
@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
-if get_intra_layer_model_parallel_world_size()==1:
+if get_tensor_model_parallel_world_size()==1:
return input_
# All-reduce.
-torch.distributed.all_reduce(input_, group=get_intra_layer_model_parallel_group())
+torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
@@ -45,7 +45,7 @@ def _split(input_):
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
-rank = get_intra_layer_model_parallel_rank()
+rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
@@ -54,18 +54,18 @@ def _split(input_):
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
-rank = get_intra_layer_model_parallel_rank()
+rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
-torch.distributed.all_gather(tensor_list, input_, group=get_intra_layer_model_parallel_group())
+torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------
-def copy_to_intra_layer_model_parallel_region(input_):
+def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
-def reduce_from_intra_layer_model_parallel_region(input_):
+def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
-def scatter_to_intra_layer_model_parallel_region(input_):
+def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
-def gather_from_intra_layer_model_parallel_region(input_):
+def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
@@ -28,13 +28,13 @@ from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
-from .initialize import get_intra_layer_model_parallel_group
+from .initialize import get_tensor_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_world_size
# Default name for the model parallel rng tracker.
-_MODEL_PARALLEL_RNG_TRACKER_NAME = 'intra-layer-model-parallel-rng'
+_MODEL_PARALLEL_RNG_TRACKER_NAME = 'tensor-model-parallel-rng'
# Whether apply model parallelsim to checkpointed hidden states.
@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
-partition_size = torch.numel(data) // get_intra_layer_model_parallel_world_size()
+partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
-start_index = partition_size * get_intra_layer_model_parallel_rank()
+start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
-world_size = get_intra_layer_model_parallel_world_size()
+world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor,
-group=get_intra_layer_model_parallel_group())
+group=get_tensor_model_parallel_group())
return gathered
@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
return _CUDA_RNG_STATE_TRACKER
-def intra_layer_model_parallel_cuda_manual_seed(seed):
+def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
@@ -215,15 +215,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
-example for dropout in the non-intra-layer-model-parallel regions.
+example for dropout in the non-tensor-model-parallel regions.
-intra-layer-model-parallel state: This state is different among a set of model
+tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
-intra_layer_model_parallel_seed = offset + get_intra_layer_model_parallel_rank()
+tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original sedd.
data_parallel_seed = seed
@@ -231,15 +231,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
-torch.distributed.get_rank(), get_intra_layer_model_parallel_rank(),
+torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
-get_data_parallel_rank(), intra_layer_model_parallel_seed,
+get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
-intra_layer_model_parallel_seed)
+tensor_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
...
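The seed bookkeeping above reduces to two numbers per rank: the default (data-parallel) state reuses the caller's seed everywhere, while the tensor-model-parallel state is offset by the tensor-parallel rank so dropout patterns differ across the shards of a partitioned layer. A tiny sketch of those values, assuming a tensor-parallel size of 2:

def cuda_seeds(seed, tensor_model_parallel_rank):
    # Mirrors the arithmetic in model_parallel_cuda_manual_seed.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + tensor_model_parallel_rank
    data_parallel_seed = seed
    return data_parallel_seed, tensor_model_parallel_seed

for tp_rank in range(2):
    print(tp_rank, cuda_seeds(1234, tp_rank))
# 0 (1234, 3952)
# 1 (1234, 3953)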
...@@ -36,7 +36,7 @@ def set_random_seed(seed): ...@@ -36,7 +36,7 @@ def set_random_seed(seed):
random.seed(seed) random.seed(seed)
numpy.random.seed(seed) numpy.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
mpu.intra_layer_model_parallel_cuda_manual_seed(seed) mpu.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'): def initialize_distributed(backend='nccl'):
......
...@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, ...@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity = IdentityLayer((batch_size, seq_length, vocab_size), identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda() scale=logits_scale).cuda()
logits = identity() logits = identity()
logits_parallel = mpu.scatter_to_intra_layer_model_parallel_region(logits) logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor( target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size) size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
...@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, ...@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return loss, identity.weight.grad return loss, identity.weight.grad
def test_cross_entropy(intra_layer_model_parallel_size): def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'. print('> testing cross entropy with model parallel size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 13 batch_size = 13
seq_length = 17 seq_length = 17
vocab_size_per_partition = 11 vocab_size_per_partition = 11
logits_scale = 1000.0 logits_scale = 1000.0
vocab_size = vocab_size_per_partition * intra_layer_model_parallel_size vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234 seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
...@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size): ...@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size):
assert error < 1.0e-6 assert error < 1.0e-6
# Reset groups # Reset groups
mpu.destroy_intra_layer_model_parallel() mpu.destroy_tensor_model_parallel()
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -101,8 +101,8 @@ if __name__ == '__main__': ...@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy') print_separator('test cross entropy')
test_cross_entropy(intra_layer_model_parallel_size) test_cross_entropy(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -24,15 +24,15 @@ import sys ...@@ -24,15 +24,15 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_broadcast_data(intra_layer_model_parallel_size): def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing broadcast_data with model parallel size {} ...'. print('> testing broadcast_data with model parallel size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank()) torch.manual_seed(1234 + mpu.get_data_parallel_rank())
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
key_size_t = {'key1': [7, 11], key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1], 'key2': [8, 2, 1],
...@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size): ...@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
data_t[key] = data[key].clone() data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone() data_t['keyX'] = data['keyX'].clone()
if mpu.get_intra_layer_model_parallel_rank() != 0: if mpu.get_tensor_model_parallel_rank() != 0:
data = None data = None
data_utils._check_data_types(keys, data_t, torch.int64) data_utils._check_data_types(keys, data_t, torch.int64)
...@@ -69,7 +69,7 @@ def test_broadcast_data(intra_layer_model_parallel_size): ...@@ -69,7 +69,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
assert data_b[key].sub(tensor).abs().max() == 0 assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups # Reset groups
mpu.destroy_intra_layer_model_parallel() mpu.destroy_tensor_model_parallel()
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -81,8 +81,8 @@ if __name__ == '__main__': ...@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test broadcast data') print_separator('test broadcast data')
test_broadcast_data(intra_layer_model_parallel_size) test_broadcast_data(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -21,15 +21,15 @@ import sys ...@@ -21,15 +21,15 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_initialize_model_parallel(intra_layer_model_parallel_size): def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format( print('> testing initialize_model_parallel with size {} ...'.format(
intra_layer_model_parallel_size)) tensor_model_parallel_size))
intra_layer_model_parallel_size_ = min(intra_layer_model_parallel_size, tensor_model_parallel_size_ = min(tensor_model_parallel_size,
torch.distributed.get_world_size()) torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized() assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(intra_layer_model_parallel_size_) mpu.initialize_model_parallel(tensor_model_parallel_size_)
assert mpu.model_parallel_is_initialized() assert mpu.model_parallel_is_initialized()
# Checks. # Checks.
...@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size): ...@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
assert rank == torch.distributed.get_rank(group=group) assert rank == torch.distributed.get_rank(group=group)
# Model parallel. # Model parallel.
world_size = intra_layer_model_parallel_size_ world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % intra_layer_model_parallel_size_ rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == mpu.get_intra_layer_model_parallel_world_size() assert world_size == mpu.get_tensor_model_parallel_world_size()
assert rank == mpu.get_intra_layer_model_parallel_rank() assert rank == mpu.get_tensor_model_parallel_rank()
check(mpu.get_intra_layer_model_parallel_group(), world_size, rank) check(mpu.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel. # Data parallel.
world_size = torch.distributed.get_world_size() // intra_layer_model_parallel_size_ world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // intra_layer_model_parallel_size rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == mpu.get_data_parallel_world_size() assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank() assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank) check(mpu.get_data_parallel_group(), world_size, rank)
...@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size): ...@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size_): def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing get_intra_layer_model_parallel_src_rank with size {} ...'.format( print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
intra_layer_model_parallel_size_)) tensor_model_parallel_size_))
intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_, tensor_model_parallel_size = min(tensor_model_parallel_size_,
torch.distributed.get_world_size()) torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized() assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
assert mpu.model_parallel_is_initialized() assert mpu.model_parallel_is_initialized()
# Checks # Checks
src_rank = torch.distributed.get_rank() - mpu.get_intra_layer_model_parallel_rank() src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
assert mpu.get_intra_layer_model_parallel_src_rank() == src_rank assert mpu.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups # Reset groups
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
...@@ -86,10 +86,10 @@ if __name__ == '__main__': ...@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel') print_separator('test initialize model parallel')
test_initialize_model_parallel(intra_layer_model_parallel_size) test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank') print_separator('test model parallel source rank')
test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size) test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
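For reference, the initialization tests above pin down how a global rank maps onto the tensor-model-parallel and data-parallel grids when only tensor model parallelism is enabled. A minimal sketch of those index computations (pure arithmetic, no distributed setup needed; the helper name is made up for illustration):

```python
def rank_mapping(global_rank: int, tensor_model_parallel_size: int):
    """Index arithmetic checked by test_initialize_model_parallel above.

    Adjacent global ranks form one tensor-model-parallel group; the first
    rank of that group is the broadcast source rank.
    """
    tensor_rank = global_rank % tensor_model_parallel_size   # get_tensor_model_parallel_rank()
    data_rank = global_rank // tensor_model_parallel_size    # get_data_parallel_rank()
    src_rank = global_rank - tensor_rank                     # get_tensor_model_parallel_src_rank()
    return tensor_rank, data_rank, src_rank

# Example: with tensor_model_parallel_size=2, global rank 5 is tensor rank 1,
# data-parallel rank 2, and its group's broadcast source rank is 4.
assert rank_mapping(5, 2) == (1, 2, 4)
```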
...@@ -26,14 +26,14 @@ import sys ...@@ -26,14 +26,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_parallel_embedding(intra_layer_model_parallel_size): def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'. print('> testing parallel embedding with model parallel size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 17 batch_size = 17
seq_length = 23 seq_length = 23
...@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size): ...@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // intra_layer_model_parallel_size, hidden_size // tensor_model_parallel_size,
1)[mpu.get_intra_layer_model_parallel_rank()] 1)[mpu.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format( print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error)) torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // intra_layer_model_parallel_size, vocab_size // tensor_model_parallel_size,
0)[mpu.get_intra_layer_model_parallel_rank()] 0)[mpu.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub( error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max() weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format( print(' error in grad (vocab parallel) on global rank {}: {}'.format(
...@@ -104,19 +104,19 @@ def test_parallel_embedding(intra_layer_model_parallel_size): ...@@ -104,19 +104,19 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_initialize_affine_weight(intra_layer_model_parallel_size): def test_initialize_affine_weight(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel ' print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(intra_layer_model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
# --------------- # ---------------
# Column parallel # Column parallel
...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size): ...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff, my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size): ...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff, my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module): ...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight return self.weight
def test_column_parallel_linear(intra_layer_model_parallel_size): def test_column_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel ' print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(intra_layer_model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size): ...@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff, my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
...@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size): ...@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
def test_row_parallel_linear(intra_layer_model_parallel_size): def test_row_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel ' print('> testing RowParallelLinear with model parallel '
'size: {}'.format(intra_layer_model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * intra_layer_model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * intra_layer_model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size): ...@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff, my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module): ...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight return self.weight
def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_partition, def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length): sequence_length):
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p ...@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, intra_layer_model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer attention_layer, identity_layer
def test_parallel_self_attention(intra_layer_model_parallel_size): def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel ' print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(intra_layer_model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
...@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size): ...@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hidden_size_1, intra_layer_model_parallel_size_1, loss_1, \ rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention( attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, intra_layer_model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention( attention_layer, identity_layer = parallel_self_attention(
intra_layer_model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hidden_size_1 == hidden_size assert hidden_size_1 == hidden_size
...@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size): ...@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
my_lin_grad_list = torch.split( my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad, attention_layer_1.query_key_value.weight.grad,
hidden_size // intra_layer_model_parallel_size, 0)[rank::intra_layer_model_parallel_size] hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0) my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub( error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max() attention_layer.query_key_value.weight.grad).abs().max()
...@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size): ...@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_partition, def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length): hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part ...@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_intra_layer_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, intra_layer_model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer transformer_layer, identity_layer
def test_parallel_transformer_layer(intra_layer_model_parallel_size): def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel ' print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(intra_layer_model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hidden_size_1, intra_layer_model_parallel_size_1, loss_1, \ rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer( transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, intra_layer_model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer( transformer_layer, identity_layer = parallel_transformer(
intra_layer_model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max() error = loss_1.sub(loss).abs().max()
...@@ -494,37 +494,37 @@ if __name__ == '__main__': ...@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight') print_separator('test initialize affine weight')
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(intra_layer_model_parallel_size) test_initialize_affine_weight(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test parallel embedding') print_separator('test parallel embedding')
test_parallel_embedding(intra_layer_model_parallel_size) test_parallel_embedding(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test column-parallel linear') print_separator('test column-parallel linear')
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(intra_layer_model_parallel_size) test_column_parallel_linear(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test row-parallel linear') print_separator('test row-parallel linear')
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(intra_layer_model_parallel_size) test_row_parallel_linear(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel self-attention') print_separator('test parallel self-attention')
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_self_attention(intra_layer_model_parallel_size) test_parallel_self_attention(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel transformer') print_separator('test parallel transformer')
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_transformer_layer(intra_layer_model_parallel_size) test_parallel_transformer_layer(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
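The weight-initialization and linear-layer tests above also encode the partitioning convention the renamed layers rely on: ColumnParallelLinear shards the master weight along the output dimension (dim 0), RowParallelLinear along the input dimension (dim 1), one chunk per tensor-model-parallel rank. A small self-contained sketch of that split with illustrative sizes (the rank is hard-coded here instead of coming from mpu.get_tensor_model_parallel_rank()):

```python
import torch

tensor_model_parallel_size = 4   # illustrative group size
rank = 1                         # stands in for mpu.get_tensor_model_parallel_rank()
input_size = 13 * tensor_model_parallel_size
output_size = 17 * tensor_model_parallel_size

master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)

# Column-parallel: each rank keeps a slice of the output dimension (dim 0).
column_shard = torch.split(master_weight, output_size // tensor_model_parallel_size,
                           dim=0)[rank].contiguous().clone()

# Row-parallel: each rank keeps a slice of the input dimension (dim 1).
row_shard = torch.split(master_weight, input_size // tensor_model_parallel_size,
                        dim=1)[rank].contiguous().clone()

assert column_shard.shape == (output_size // tensor_model_parallel_size, input_size)
assert row_shard.shape == (output_size, input_size // tensor_model_parallel_size)
```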
...@@ -21,14 +21,14 @@ import sys ...@@ -21,14 +21,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_set_cuda_rng_state(intra_layer_model_parallel_size): def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'. print('> testing set_rng_state with size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
size = 123 size = 123
seed = 1234 seed = 1234
...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size): ...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_cuda_rng_tracker(intra_layer_model_parallel_size): def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'. print('> testing cuda rng tracker with size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed_1 = 1234 seed_1 = 1234
seed_2 = 4321 seed_2 = 4321
...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size): ...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size): def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'. print('> testing model parallel cuda manual seed with size {} ...'.
format(intra_layer_model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(intra_layer_model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
mpu.intra_layer_model_parallel_cuda_manual_seed(12345) mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345 assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 + assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_intra_layer_model_parallel_rank()) mpu.get_tensor_model_parallel_rank())
# Reset the tracker # Reset the tracker
mpu.get_cuda_rng_tracker().reset() mpu.get_cuda_rng_tracker().reset()
...@@ -185,20 +185,20 @@ if __name__ == '__main__': ...@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test set rng state') print_separator('test set rng state')
test_set_cuda_rng_state(intra_layer_model_parallel_size) test_set_cuda_rng_state(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker') print_separator('test cuda rng tracker')
test_cuda_rng_tracker(intra_layer_model_parallel_size) test_cuda_rng_tracker(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
intra_layer_model_parallel_size = 1 tensor_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed') print_separator('test model parallel cuda manual seed')
test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size) test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
intra_layer_model_parallel_size *= 2 tensor_model_parallel_size *= 2
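The renamed test_model_parallel_cuda_manual_seed above asserts that the default CUDA seed stays at the value passed in, while the tensor-model-parallel RNG tracker forks to that value offset by 2718 plus the tensor-model-parallel rank, presumably so that randomness in partitioned regions differs across tensor-parallel ranks. A tiny sketch of that seed split, assuming a hypothetical standalone helper (the real logic lives in mpu/random.py and manages CUDA RNG state trackers):

```python
def split_seeds(base_seed: int, tensor_model_parallel_rank: int):
    """Return (default_rng_seed, tensor_parallel_rng_seed) for one rank.

    Mirrors the assertion `12345 + 2718 + get_tensor_model_parallel_rank()`
    in the test above.
    """
    default_seed = base_seed
    tensor_parallel_seed = base_seed + 2718 + tensor_model_parallel_rank
    return default_seed, tensor_parallel_seed

assert split_seeds(12345, 0) == (12345, 15063)
assert split_seeds(12345, 3) == (12345, 15066)
```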
...@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model): ...@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file. # Read the sample file and open the output file.
assert args.sample_input_file is not None, \ assert args.sample_input_file is not None, \
'sample input file is not provided.' 'sample input file is not provided.'
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r") fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines() all_raw_text = fname.readlines()
input_count = len(all_raw_text) input_count = len(all_raw_text)
...@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model): ...@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group()) torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0 terminate_runs = 0
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos] raw_text = all_raw_text[input_pos]
input_pos += 1 input_pos += 1
if input_pos == input_count: if input_pos == input_count:
...@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model): ...@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_intra_layer_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1: if terminate_runs == 1:
...@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model): ...@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize( trim_decode_tokens = tokenizer.detokenize(
...@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model): ...@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group()) torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
context_count += 1 context_count += 1
...@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group()) torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
terminate_runs = 0 terminate_runs = 0
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ") raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text: while not raw_text:
...@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_intra_layer_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1: if terminate_runs == 1:
...@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_intra_layer_model_parallel_rank() == 0 and \ if mpu.get_tensor_model_parallel_rank() == 0 and \
counter % print_frequency == 0: counter % print_frequency == 0:
os.system('clear') os.system('clear')
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
...@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens)[len(raw_text):] decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize( trim_decode_tokens = tokenizer.detokenize(
...@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group()) torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
context_count += 1 context_count += 1
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
input("\nPress any key to continue >>>") input("\nPress any key to continue >>>")
...@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens): ...@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
context_length_tensor = torch.cuda.LongTensor(context_lengths) context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor, torch.distributed.broadcast(context_length_tensor,
mpu.get_intra_layer_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, torch.distributed.broadcast(context_tokens_tensor,
mpu.get_intra_layer_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item() context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor) tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
......
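The generation code above uses one recurring pattern for the renamed groups: only rank 0 of each tensor-model-parallel group reads or builds the control data, and every other rank receives it via a broadcast from get_tensor_model_parallel_src_rank() over get_tensor_model_parallel_group(). A condensed sketch of that pattern, assuming distributed state and the mpu groups are already initialized and that mpu is importable as in the rest of the repository:

```python
import torch
from megatron import mpu  # assumed import path; the diff only shows `mpu.` call sites


def broadcast_from_tensor_parallel_src(value: int) -> int:
    """Broadcast one scalar from the source rank of the tensor-model-parallel group.

    Condenses the terminate_runs / context-length broadcasts above: the source
    rank supplies the value, the other ranks of the group overwrite theirs.
    """
    tensor = torch.cuda.LongTensor([value])
    torch.distributed.broadcast(tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    return tensor[0].item()
```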
...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args): ...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \ multiple = args.make_vocab_size_divisible_by * \
args.intra_layer_model_parallel_size args.tensor_model_parallel_size
while (after % multiple) != 0: while (after % multiple) != 0:
after += 1 after += 1
if args.rank == 0: if args.rank == 0:
......
...@@ -124,10 +124,10 @@ def get_model(model_provider_func): ...@@ -124,10 +124,10 @@ def get_model(model_provider_func):
# Print number of parameters. # Print number of parameters.
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (intra-layer, inter-layer) ' print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format( 'model parallel rank ({}, {}): {}'.format(
mpu.get_intra_layer_model_parallel_rank(), mpu.get_tensor_model_parallel_rank(),
mpu.get_inter_layer_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True) sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation. # GPU allocation.
...@@ -166,8 +166,8 @@ def get_optimizer(model): ...@@ -166,8 +166,8 @@ def get_optimizer(model):
# Add model parallel attribute if it is not set. # Add model parallel attribute if it is not set.
for param_group in param_groups: for param_group in param_groups:
for param in param_group['params']: for param in param_group['params']:
if not hasattr(param, 'intra_layer_model_parallel'): if not hasattr(param, 'tensor_model_parallel'):
param.intra_layer_model_parallel = False param.tensor_model_parallel = False
# Use Adam. # Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay, optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
...@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward) ...@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
tensor_recv_prev=tensor_recv_prev, tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next, tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next, tensor_recv_next=tensor_recv_next,
group=mpu.get_inter_layer_model_parallel_group()) group=mpu.get_pipeline_model_parallel_group())
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
...@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator,
optimizer.zero_grad() optimizer.zero_grad()
# Compute number of microbatches in a minibatch. # Compute number of microbatches in a minibatch.
num_microbatches_to_pipeline = args.inter_layer_model_parallel_size \ num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
if args.use_pipelining else 1 if args.use_pipelining else 1
input_tensors = [] input_tensors = []
...@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator,
# Run forward pass for all microbatches in minibatch. # Run forward pass for all microbatches in minibatch.
for i in range(num_microbatches_to_pipeline): for i in range(num_microbatches_to_pipeline):
if not mpu.is_inter_layer_first_stage(): if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate( input_tensor, _ = communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
output_tensor = forward_step_func(data_iterator, model, input_tensor) output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward').stop() timers('forward').stop()
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor loss, loss_reduced = output_tensor
output_tensor = loss output_tensor = loss
losses_reduced.append(loss_reduced) losses_reduced.append(loss_reduced)
...@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator,
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
output_grad_tensor = None output_grad_tensor = None
else: else:
_, output_grad_tensor = communicate( _, output_grad_tensor = communicate(
...@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator,
backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor) backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
timers('backward').stop() timers('backward').stop()
if not mpu.is_inter_layer_first_stage(): if not mpu.is_pipeline_first_stage():
communicate( communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_grad_tensor, tensor_send_prev=input_grad_tensor,
...@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator, ...@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator,
timers('backward-master-grad').stop() timers('backward-master-grad').stop()
# All-reduce across first and last stages. # All-reduce across first and last stages.
if (mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage()) and \ if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
args.inter_layer_model_parallel_size > 1: args.pipeline_model_parallel_size > 1:
unwrapped_model = model unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)): while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module unwrapped_model = unwrapped_model.module
...@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator,
else: else:
skipped_iter = 1 skipped_iter = 1
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
# Average loss across microbatches. # Average loss across microbatches.
loss_reduced = {} loss_reduced = {}
for key in losses_reduced[0]: for key in losses_reduced[0]:
...@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration, print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters)) args.eval_iters))
if not mpu.is_inter_layer_first_stage(): if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate( input_tensor, _ = communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
...@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Forward evaluation. # Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor) output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
_, loss_dict = output_tensor _, loss_dict = output_tensor
# Reduce across processes. # Reduce across processes.
for key in loss_dict: for key in loss_dict:
...@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators( ...@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators(
print_rank_0('> building train, validation, and test datasets ...') print_rank_0('> building train, validation, and test datasets ...')
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_intra_layer_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
# Rank, size, and global batch size. # Rank, size, and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size() data_parallel_size = mpu.get_data_parallel_world_size()
global_batch_size = args.batch_size * data_parallel_size global_batch_size = args.batch_size * data_parallel_size
...@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators( ...@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators(
# Broadcast num tokens. # Broadcast num tokens.
torch.distributed.broadcast(flags, torch.distributed.broadcast(flags,
mpu.get_intra_layer_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item() args.do_train = flags[0].item()
args.do_valid = flags[1].item() args.do_valid = flags[1].item()
args.do_test = flags[2].item() args.do_test = flags[2].item()
......
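The training.py changes above boil down to a simple per-microbatch schedule on each pipeline stage: receive the input activation from the previous stage unless this is the first stage, run the forward step, start the backward pass from the loss on the last stage or from the gradient received from the next stage otherwise, and finally send the input gradient back unless this is the first stage. A stripped-down sketch of that control flow, assuming communicate keeps the keyword interface shown above and that backward_step returns the input gradient (how input_grad_tensor is produced, and how the forward output is sent to the next stage, are elided in the diff and omitted here too):

```python
from megatron import mpu  # assumed import path

# Sketch of one microbatch on one pipeline stage, following the branches in
# train_step above. Timers, overlapping of multiple microbatches, and the
# embedding all-reduce between first and last stages are omitted.

def run_one_microbatch(forward_step_func, data_iterator, model, optimizer,
                       communicate, backward_step):
    # Receive activations from the previous stage, unless we are stage 0.
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
    else:
        input_tensor = None

    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss
    # (Sending output_tensor to the next stage is elided in the diff.)

    # The last stage starts backprop from the loss; other stages first receive
    # dL/d(output) from the next stage.
    if mpu.is_pipeline_last_stage():
        output_grad_tensor = None
    else:
        _, output_grad_tensor = communicate(tensor_send_next=None,
                                            tensor_send_prev=None,
                                            recv_forward=False,
                                            recv_backward=True)
    input_grad_tensor = backward_step(optimizer, model, input_tensor,
                                      output_tensor, output_grad_tensor)  # assumed return value

    # Send the input gradient to the previous stage, unless we are stage 0.
    if not mpu.is_pipeline_first_stage():
        communicate(tensor_send_next=None,
                    tensor_send_prev=input_grad_tensor,
                    recv_forward=False,
                    recv_backward=False)
    return loss_reduced if mpu.is_pipeline_last_stage() else None
```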
...@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters.""" """Print min, max, and norm of all parameters."""
index = 0 index = 0
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
string = 'iteration, rank, index, intra-layer-model-parallel, min, max, norm\n' string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer): if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer optimizer_ = optimizer.optimizer
...@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max() max_ = param.data.max()
norm = param.data.norm() norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.intra_layer_model_parallel)) iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True) print(string, flush=True)
......
...@@ -34,12 +34,12 @@ def model_provider(): ...@@ -34,12 +34,12 @@ def model_provider():
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
args = get_args() args = get_args()
if args.inter_layer_model_parallel_size > 1: if args.pipeline_model_parallel_size > 1:
# Determine model based on position of stage in pipeline. # Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
model = BertModelFirstStage( model = BertModelFirstStage(
num_tokentypes=2) num_tokentypes=2)
elif mpu.is_inter_layer_last_stage(): elif mpu.is_pipeline_last_stage():
model = BertModelLastStage( model = BertModelLastStage(
num_tokentypes=2, num_tokentypes=2,
add_binary_head=True, add_binary_head=True,
...@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers('batch generator').stop() timers('batch generator').stop()
# Forward pass through the model. # Forward pass through the model.
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
assert input_tensor is None assert input_tensor is None
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types, output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels) lm_labels=lm_labels)
else: else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types) output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_inter_layer_last_stage(): elif mpu.is_pipeline_last_stage():
assert input_tensor is not None assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels) output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else: else:
assert input_tensor is not None assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask) output_tensor = model(input_tensor, padding_mask)
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
......
...@@ -33,11 +33,11 @@ def model_provider(): ...@@ -33,11 +33,11 @@ def model_provider():
print_rank_0('building GPT2 model ...') print_rank_0('building GPT2 model ...')
args = get_args() args = get_args()
if args.inter_layer_model_parallel_size > 1: if args.pipeline_model_parallel_size > 1:
# Determine model based on position of stage in pipeline. # Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0) model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_inter_layer_last_stage(): elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage( model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True) num_tokentypes=0, parallel_output=True)
else: else:
...@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers('batch generator').stop() timers('batch generator').stop()
# Forward pass through the model. # Forward pass through the model.
if mpu.is_inter_layer_first_stage(): if mpu.is_pipeline_first_stage():
assert input_tensor is None assert input_tensor is None
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask, output_tensor = model(tokens, position_ids, attention_mask,
labels=labels) labels=labels)
else: else:
output_tensor = model(tokens, position_ids, attention_mask) output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_inter_layer_last_stage(): elif mpu.is_pipeline_last_stage():
assert input_tensor is not None assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels) output_tensor = model(input_tensor, attention_mask, labels=labels)
else: else:
assert input_tensor is not None assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask) output_tensor = model(input_tensor, attention_mask)
if mpu.is_inter_layer_last_stage(): if mpu.is_pipeline_last_stage():
losses = output_tensor.float() losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float() loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
......