Commit 7abd3e90 authored by Deepak Narayanan

Pipeline parallelism implementation with periodic full-pipeline syncs

Also includes the following changes for the inter-layer model-parallel implementation:
- Refactoring of model implementations
- Training loop changes to support inter-layer communication using `ring_exchange`
- New groups for inter-layer communication (illustrated in the sketch below)
- Checkpoint changes
- Command line arguments
parent 28cd66e1
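
For orientation, here is a small illustration (not part of the commit) of how ranks are partitioned into the new data-parallel, intra-layer, inter-layer, and embedding groups. It mirrors the grouping logic added to `megatron/mpu/initialize.py` further down in this diff and reproduces the 16-GPU example in the `initialize_model_parallel()` docstring; the helper name `enumerate_parallel_groups` is invented for this sketch.

```python
# Illustration only: enumerate the rank groups created by the new
# initialize_model_parallel(intra_size, inter_size) grouping logic.
def enumerate_parallel_groups(world_size, intra_size, inter_size):
    assert world_size % (intra_size * inter_size) == 0
    num_intra_groups = world_size // intra_size
    num_inter_groups = world_size // inter_size

    # Data-parallel groups: ranks holding the same model shard
    # (same intra-layer slice of the same pipeline stage).
    data_groups = []
    for i in range(inter_size):
        start, end = i * num_inter_groups, (i + 1) * num_inter_groups
        for j in range(intra_size):
            data_groups.append(list(range(start + j, end, intra_size)))

    # Intra-layer (tensor) model-parallel groups: consecutive ranks.
    intra_groups = [list(range(i * intra_size, (i + 1) * intra_size))
                    for i in range(num_intra_groups)]

    # Inter-layer (pipeline) model-parallel groups: one rank per stage.
    inter_groups = [list(range(i, world_size, num_inter_groups))
                    for i in range(num_inter_groups)]

    # Embedding groups: first and last rank of each pipeline.
    embed_groups = [[g[0], g[-1]] if len(g) > 1 else g for g in inter_groups]
    return data_groups, intra_groups, inter_groups, embed_groups

# With 16 GPUs, 2-way intra-layer and 4-way inter-layer parallelism this
# matches the docstring example below, e.g. inter-layer groups
# [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15].
```
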
......@@ -29,7 +29,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
input_parallel = mpu.copy_to_intra_layer_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
......@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if parallel_output:
return logits_parallel
return mpu.gather_from_model_parallel_region(logits_parallel)
return mpu.gather_from_intra_layer_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
......@@ -54,12 +54,24 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
# Language model.
language_model = TransformerLanguageModel(
attention_mask_func=attention_mask_func,
init_method=init_method,
output_layer_init_method=scaled_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
args = [attention_mask_func, init_method, scaled_init_method]
kwargs = {}
cls = None
if mpu.is_inter_layer_first_stage() and mpu.is_inter_layer_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_pooler'] = add_pooler
elif mpu.is_inter_layer_first_stage() and not mpu.is_inter_layer_last_stage():
cls = TransformerLanguageModelFirstStage
kwargs['num_tokentypes'] = num_tokentypes
elif not mpu.is_inter_layer_first_stage() and mpu.is_inter_layer_last_stage():
cls = TransformerLanguageModelLastStage
kwargs['add_pooler'] = add_pooler
else:
cls = TransformerLanguageModelIntermediateStage
# Language model.
language_model = cls(*args, **kwargs)
# key used for checkpoints.
language_model_key = 'language_model'
......@@ -118,9 +130,12 @@ class Embedding(MegatronModule):
self.init_method = init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
vocab_size, self.hidden_size,
init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
......@@ -160,6 +175,7 @@ class Embedding(MegatronModule):
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
args = get_args()
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
......@@ -241,7 +257,7 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
class TransformerLanguageModelBase(MegatronModule):
"""Transformer language model.
Arguments:
......@@ -266,7 +282,7 @@ class TransformerLanguageModel(MegatronModule):
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
......@@ -274,41 +290,47 @@ class TransformerLanguageModel(MegatronModule):
self.init_method = init_method
self.add_pooler = add_pooler
# Embeddings
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Embeddings.
if mpu.is_inter_layer_first_stage():
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Transformer
# Transformer.
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler
if self.add_pooler:
# Pooler.
if mpu.is_inter_layer_last_stage() and self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
def forward(self, language_model_input, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
# Embeddings.
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
if mpu.is_inter_layer_first_stage():
(input_ids, position_ids) = language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
transformer_input = embedding_output
else:
transformer_input = language_model_input
# Transformer.
transformer_output = self.transformer(embedding_output,
transformer_output = self.transformer(transformer_input,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if self.add_pooler:
if mpu.is_inter_layer_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
......@@ -320,13 +342,14 @@ class TransformerLanguageModel(MegatronModule):
"""For easy load."""
state_dict_ = {}
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_inter_layer_first_stage():
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_pooler:
if mpu.is_inter_layer_last_stage() and self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
......@@ -337,15 +360,16 @@ class TransformerLanguageModel(MegatronModule):
"""Customized load."""
# Embedding.
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
if mpu.is_inter_layer_first_stage():
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
......@@ -359,8 +383,118 @@ class TransformerLanguageModel(MegatronModule):
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.add_pooler:
if mpu.is_inter_layer_last_stage() and self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
"""Transformer language model (see TransformerLanguageModelBase
for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""Transformer language model, first stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(TransformerLanguageModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""Transformer language model, intermediate stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
return super(TransformerLanguageModelIntermediateStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""Transformer language model, final stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModelLastStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
)
......@@ -68,8 +68,7 @@ class MultipleChoice(MegatronModule):
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids,
position_ids,
_, pooled_output = self.language_model(input_ids, position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
......
......@@ -19,7 +19,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.model_parallel_size == 1, \
assert args.intra_layer_model_parallel_size == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
......@@ -172,8 +172,7 @@ class IREncoderBertModel(MegatronModule):
position_ids = bert_position_ids(input_ids)
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
input_ids, position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
......
......@@ -130,7 +130,7 @@ class ParallelSelfAttention(MegatronModule):
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
world_size = mpu.get_intra_layer_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
......@@ -504,13 +504,15 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers:
self.num_layers = args.num_layers
self.num_unique_layers = args.num_unique_layers
if self.num_unique_layers is None:
# Number of layers.
self.num_layers = args.num_layers // args.inter_layer_model_parallel_size
# TODO: Need to do something different in case self.num_layers != self.num_unique_layers?
if args.num_unique_layers is None:
self.num_unique_layers = self.num_layers
assert self.num_layers % self.num_unique_layers == 0, \
'number of layers should be divisible by number of unique layers'
else:
self.num_unique_layers = args.num_unique_layers // args.inter_layer_model_parallel_size
assert self.num_layers == self.num_unique_layers, \
'number of layers should be equal to the number of unique layers'
self.param_sharing_style = args.param_sharing_style
# Transformer layers.
......@@ -518,8 +520,9 @@ class ParallelTransformer(MegatronModule):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
offset = mpu.get_inter_layer_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1) for i in range(self.num_unique_layers)])
[build_layer(i + 1 + offset) for i in range(self.num_unique_layers)])
# Print layer ordering.
if self.num_layers != self.num_unique_layers:
......@@ -530,10 +533,11 @@ class ParallelTransformer(MegatronModule):
'{:3d}'.format(i, self._get_layer_index(i)),
flush=True)
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if mpu.is_inter_layer_last_stage():
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _get_layer_index(self, layer_number):
if self.param_sharing_style == 'grouped':
......@@ -606,7 +610,10 @@ class ParallelTransformer(MegatronModule):
hidden_states = hidden_states.transpose(0, 1).contiguous()
# Final layer norm.
output = self.final_layernorm(hidden_states)
if mpu.is_inter_layer_last_stage():
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
if get_key_value:
output = [output, presents]
......
......@@ -26,10 +26,17 @@ from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank, set_model_parallel_rank
from .initialize import get_model_parallel_src_rank
from .initialize import get_model_parallel_world_size, set_model_parallel_world_size
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_inter_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_rank
from .initialize import get_inter_layer_model_parallel_rank, set_inter_layer_model_parallel_rank
from .initialize import is_inter_layer_first_stage, is_inter_layer_last_stage
from .initialize import get_intra_layer_model_parallel_src_rank
from .initialize import get_inter_layer_model_parallel_src_rank
from .initialize import get_intra_layer_model_parallel_world_size, set_intra_layer_model_parallel_world_size
from .initialize import get_inter_layer_model_parallel_world_size, set_inter_layer_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
......@@ -38,15 +45,15 @@ from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import copy_to_intra_layer_model_parallel_region
from .mappings import gather_from_intra_layer_model_parallel_region
from .mappings import reduce_from_intra_layer_model_parallel_region
from .mappings import scatter_to_intra_layer_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import intra_layer_model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .utils import divide
......
......@@ -16,9 +16,9 @@
import torch
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .utils import VocabUtility
......@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
group=get_intra_layer_model_parallel_group())
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
rank = get_intra_layer_model_parallel_rank()
world_size = get_intra_layer_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size)
......@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
group=get_intra_layer_model_parallel_group())
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
......@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
group=get_intra_layer_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
......
......@@ -15,9 +15,9 @@
import torch
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_src_rank
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_src_rank
_MAX_DATA_DIM = 4
......@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_model_parallel_rank() == 0:
if get_intra_layer_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
......@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
group=get_model_parallel_group())
torch.distributed.broadcast(sizes_cuda, get_intra_layer_model_parallel_src_rank(),
group=get_intra_layer_model_parallel_group())
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
......@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data)
# Pack on rank zero.
if get_model_parallel_rank() == 0:
if get_intra_layer_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
......@@ -100,9 +100,9 @@ def broadcast_data(keys, data, datatype):
device=torch.cuda.current_device(),
dtype=datatype)
# Boradcast
torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
group=get_model_parallel_group())
# Broadcast
torch.distributed.broadcast(flatten_data, get_intra_layer_model_parallel_src_rank(),
group=get_intra_layer_model_parallel_group())
# Unpack
output = {}
......
......@@ -28,8 +28,9 @@ try:
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_inter_layer_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_rank
def l2_grad_clipper(parameters, max_norm):
......@@ -43,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_model_parallel_rank() == 0)
mp_rank_is_zero = (get_intra_layer_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
lambda p: p.intra_layer_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
......@@ -71,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2):
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
......@@ -90,13 +91,27 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_inter_layer_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
......@@ -105,16 +120,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
#elif norm_type == 2:
# total_norm = l2_grad_clipper(parameters, max_norm)
else:
total_norm = 0
for p in parameters:
if p.model_parallel or (get_model_parallel_rank() == 0):
if p.intra_layer_model_parallel or (get_intra_layer_model_parallel_rank() == 0):
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model parallel GPUs.
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
......
......@@ -21,14 +21,22 @@ import torch
from .utils import ensure_divisibility
# Model parallel group that the current rank belongs to.
# Intra-layer model parallel group that the current rank belongs to.
_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_INTER_LAYER_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and inter-layer) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
_MPU_INTRA_LAYER_WORLD_SIZE = None
_MPU_INTER_LAYER_WORLD_SIZE = None
_MPU_INTRA_LAYER_RANK = None
_MPU_INTER_LAYER_RANK = None
def is_unitialized():
......@@ -36,60 +44,120 @@ def is_unitialized():
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(model_parallel_size_):
def initialize_model_parallel(intra_layer_model_parallel_size_=1,
inter_layer_model_parallel_size_=1):
"""
Initialize model data parallel groups.
Arguments:
model_parallel_size: number of GPUs used to parallelize model.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel grous as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
intra_layer_model_parallel_size: number of GPUs used to parallelize model intra-layer.
inter_layer_model_parallel_size: number of GPUs used to parallelize model inter-layer.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model intra-layer, and 4 GPUs to parallelize
the model inter-layer. The present function will
create 8 intra-layer model-parallel groups, 4 inter-layer model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 intra-layer model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 inter-layer model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print('> initializing model parallel with size {}'.format(
model_parallel_size_))
print('> initializing intra-layer model parallel with size {}'.format(
intra_layer_model_parallel_size_))
print('> initializing inter-layer model parallel with size {}'.format(
inter_layer_model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size)
intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_, world_size)
inter_layer_model_parallel_size = min(inter_layer_model_parallel_size_, world_size)
ensure_divisibility(world_size,
intra_layer_model_parallel_size * inter_layer_model_parallel_size)
data_parallel_size = world_size // (intra_layer_model_parallel_size *
inter_layer_model_parallel_size)
num_intra_layer_model_parallel_groups = world_size // intra_layer_model_parallel_size
num_inter_layer_model_parallel_groups = world_size // inter_layer_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank()
# Build the data parallel groups.
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
_DATA_PARALLEL_GROUP = group
# Build the model parallel groups.
all_data_parallel_group_ranks = []
for i in range(inter_layer_model_parallel_size):
start_rank = i * num_inter_layer_model_parallel_groups
end_rank = (i + 1) * num_inter_layer_model_parallel_groups
for j in range(intra_layer_model_parallel_size):
ranks = range(start_rank + j, end_rank,
intra_layer_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized'
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size,
(i + 1) * model_parallel_size)
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the intra-layer model-parallel groups.
global _INTRA_LAYER_MODEL_PARALLEL_GROUP
assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is None, \
'intra-layer model parallel group is already initialized'
for i in range(num_intra_layer_model_parallel_groups):
ranks = range(i * intra_layer_model_parallel_size,
(i + 1) * intra_layer_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_INTRA_LAYER_MODEL_PARALLEL_GROUP = group
# Build the inter-layer model-parallel groups and embedding groups
# (first and last rank in each inter-layer model-parallel group).
global _INTER_LAYER_MODEL_PARALLEL_GROUP
assert _INTER_LAYER_MODEL_PARALLEL_GROUP is None, \
'inter-layer model parallel group is already initialized'
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_inter_layer_model_parallel_groups):
ranks = range(i, world_size,
num_inter_layer_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_INTER_LAYER_MODEL_PARALLEL_GROUP = group
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
if _INTRA_LAYER_MODEL_PARALLEL_GROUP is None or \
_INTER_LAYER_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
......@@ -101,6 +169,20 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def get_intra_layer_model_parallel_group():
"""Get the intra-layer model parallel group the caller rank belongs to."""
assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _INTRA_LAYER_MODEL_PARALLEL_GROUP
def get_inter_layer_model_parallel_group():
"""Get the inter-layer model parallel group the caller rank belongs to."""
assert _INTER_LAYER_MODEL_PARALLEL_GROUP is not None, \
'inter_layer_model parallel group is not initialized'
return _INTER_LAYER_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
......@@ -108,42 +190,97 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
def set_model_parallel_world_size(world_size):
"""Set the model parallel size"""
global _MPU_WORLD_SIZE
_MPU_WORLD_SIZE = world_size
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, \
'embedding group is not initialized'
return _EMBEDDING_GROUP
def set_intra_layer_model_parallel_world_size(world_size):
"""Set the intra-layer model parallel size"""
global _MPU_INTRA_LAYER_WORLD_SIZE
_MPU_INTRA_LAYER_WORLD_SIZE = world_size
def set_inter_layer_model_parallel_world_size(world_size):
"""Set the inter-layer model parallel size"""
global _MPU_INTER_LAYER_WORLD_SIZE
_MPU_INTER_LAYER_WORLD_SIZE = world_size
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
global _MPU_WORLD_SIZE
if _MPU_WORLD_SIZE is not None:
return _MPU_WORLD_SIZE
return torch.distributed.get_world_size(group=get_model_parallel_group())
def get_intra_layer_model_parallel_world_size():
"""Return world size for the intra-layer model parallel group."""
global _MPU_INTRA_LAYER_WORLD_SIZE
if _MPU_INTRA_LAYER_WORLD_SIZE is not None:
return _MPU_INTRA_LAYER_WORLD_SIZE
return torch.distributed.get_world_size(group=get_intra_layer_model_parallel_group())
def set_model_parallel_rank(rank):
"""Set model parallel rank."""
global _MPU_RANK
_MPU_RANK = rank
def get_inter_layer_model_parallel_world_size():
"""Return world size for the inter-layer model parallel group."""
global _MPU_INTER_LAYER_WORLD_SIZE
if _MPU_INTER_LAYER_WORLD_SIZE is not None:
return _MPU_INTER_LAYER_WORLD_SIZE
return torch.distributed.get_world_size(group=get_inter_layer_model_parallel_group())
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
global _MPU_RANK
if _MPU_RANK is not None:
return _MPU_RANK
return torch.distributed.get_rank(group=get_model_parallel_group())
def set_intra_layer_model_parallel_rank(rank):
"""Set intra-layer model parallel rank."""
global _MPU_INTRA_LAYER_RANK
_MPU_INTRA_LAYER_RANK = rank
def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zeor
in the model parallel group."""
def set_inter_layer_model_parallel_rank(rank):
"""Set inter-layer model parallel rank."""
global _MPU_INTER_LAYER_RANK
_MPU_INTER_LAYER_RANK = rank
def get_intra_layer_model_parallel_rank():
"""Return my rank for the intra-layer model parallel group."""
global _MPU_INTRA_LAYER_RANK
if _MPU_INTRA_LAYER_RANK is not None:
return _MPU_INTRA_LAYER_RANK
return torch.distributed.get_rank(group=get_intra_layer_model_parallel_group())
def get_inter_layer_model_parallel_rank():
"""Return my rank for the inter-layer model parallel group."""
global _MPU_INTER_LAYER_RANK
if _MPU_INTER_LAYER_RANK is not None:
return _MPU_INTER_LAYER_RANK
return torch.distributed.get_rank(group=get_inter_layer_model_parallel_group())
def is_inter_layer_first_stage():
"""Return True if in the first inter-layer model-parallel stage, False otherwise."""
return get_inter_layer_model_parallel_rank() == 0
def is_inter_layer_last_stage():
"""Return True if in the last inter-layer model-parallel stage, False otherwise."""
return get_inter_layer_model_parallel_rank() == (
get_inter_layer_model_parallel_world_size() - 1)
def get_intra_layer_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
in the intra-layer model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size()
local_world_size = get_intra_layer_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_inter_layer_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank
in the inter-layer model parallel group."""
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
local_world_size = get_inter_layer_model_parallel_world_size()
return global_rank % (global_world_size // local_world_size)
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
......@@ -156,7 +293,9 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _INTRA_LAYER_MODEL_PARALLEL_GROUP
_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
global _INTER_LAYER_MODEL_PARALLEL_GROUP
_INTER_LAYER_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
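
To make the API surface above concrete, here is a rough usage sketch (not from the commit; the sizes, the `setup_parallel_state` helper, and the assumption that `torch.distributed` is already initialized and that `mpu` is available via the usual `from megatron import mpu` import are all illustrative):

```python
# Rough usage sketch of the new intra-/inter-layer mpu API (illustrative only).
import torch
from megatron import mpu

def setup_parallel_state(intra_size=2, inter_size=4):
    # Assumes a launcher has already initialized torch.distributed.
    assert torch.distributed.is_initialized()
    # Positional arguments: intra-layer size first, then inter-layer size.
    mpu.initialize_model_parallel(intra_size, inter_size)
    print('intra-layer rank {}, inter-layer rank {}, data-parallel rank {}'.format(
        mpu.get_intra_layer_model_parallel_rank(),
        mpu.get_inter_layer_model_parallel_rank(),
        mpu.get_data_parallel_rank()))
    if mpu.is_inter_layer_first_stage():
        # This stage builds the embedding; its gradients are exchanged with
        # the last stage over mpu.get_embedding_group().
        pass
    if mpu.is_inter_layer_last_stage():
        # This stage builds the final layer norm and any pooler/output head.
        pass
```
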
......@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .mappings import copy_to_intra_layer_model_parallel_region
from .mappings import gather_from_intra_layer_model_parallel_region
from .mappings import reduce_from_intra_layer_model_parallel_region
from .mappings import scatter_to_intra_layer_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
......@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.model_parallel = True
weight.intra_layer_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
weight.model_parallel = True
weight.intra_layer_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
......@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
......@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.model_parallel_size = get_model_parallel_world_size()
self.intra_layer_model_parallel_size = get_intra_layer_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(),
self.model_parallel_size)
self.num_embeddings, get_intra_layer_model_parallel_rank(),
self.intra_layer_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
......@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim=0, stride=1)
def forward(self, input_):
if self.model_parallel_size > 1:
if self.intra_layer_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
......@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
if self.model_parallel_size > 1:
if self.intra_layer_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
output = reduce_from_intra_layer_model_parallel_region(output_parallel)
return output
......@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.model_parallel = True
self.bias.intra_layer_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero.
......@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
input_parallel = copy_to_intra_layer_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
output = gather_from_intra_layer_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
......@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
......@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
input_parallel = scatter_to_intra_layer_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
output_ = reduce_from_intra_layer_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
......
......@@ -15,7 +15,7 @@
import torch
from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_group, get_intra_layer_model_parallel_world_size, get_intra_layer_model_parallel_rank
from .utils import split_tensor_along_last_dim
......@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_model_parallel_world_size()==1:
if get_intra_layer_model_parallel_world_size()==1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_model_parallel_group())
torch.distributed.all_reduce(input_, group=get_intra_layer_model_parallel_group())
return input_
......@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
......@@ -45,7 +45,7 @@ def _split(input_):
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_model_parallel_rank()
rank = get_intra_layer_model_parallel_rank()
output = input_list[rank].contiguous()
return output
......@@ -54,18 +54,18 @@ def _split(input_):
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size==1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_model_parallel_rank()
rank = get_intra_layer_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
torch.distributed.all_gather(tensor_list, input_, group=get_intra_layer_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
......@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_):
def copy_to_intra_layer_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
def reduce_from_intra_layer_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
def scatter_to_intra_layer_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
def gather_from_intra_layer_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
......@@ -28,13 +28,13 @@ from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .initialize import get_intra_layer_model_parallel_group
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'intra-layer-model-parallel-rng'
# Whether to apply model parallelism to checkpointed hidden states.
......@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_model_parallel_world_size()
start_index = partition_size * get_model_parallel_rank()
partition_size = torch.numel(data) // get_intra_layer_model_parallel_world_size()
start_index = partition_size * get_intra_layer_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_model_parallel_world_size()
world_size = get_intra_layer_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
......@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor,
group=get_model_parallel_group())
group=get_intra_layer_model_parallel_group())
return gathered
......@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
def intra_layer_model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
......@@ -215,15 +215,15 @@ def model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
example for dropout in the non-intra-layer-model-parallel regions.
intra-layer-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + get_model_parallel_rank()
intra_layer_model_parallel_seed = offset + get_intra_layer_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
......@@ -231,15 +231,15 @@ def model_parallel_cuda_manual_seed(seed):
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_model_parallel_rank(),
get_data_parallel_rank(), model_parallel_seed,
torch.distributed.get_rank(), get_intra_layer_model_parallel_rank(),
get_data_parallel_rank(), intra_layer_model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
model_parallel_seed)
intra_layer_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
......
......@@ -36,7 +36,7 @@ def set_random_seed(seed):
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
mpu.intra_layer_model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
......
......@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
logits_parallel = mpu.scatter_to_intra_layer_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
......@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return loss, identity.weight.grad
def test_cross_entropy(model_parallel_size):
def test_cross_entropy(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(model_parallel_size))
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * model_parallel_size
vocab_size = vocab_size_per_partition * intra_layer_model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
......@@ -89,7 +89,7 @@ def test_cross_entropy(model_parallel_size):
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
mpu.destroy_intra_layer_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(model_parallel_size)
model_parallel_size *= 2
test_cross_entropy(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
......@@ -24,15 +24,15 @@ import sys
sys.path.append("../..")
def test_boradcast_data(model_parallel_size):
def test_broadcast_data(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing boradcast_data with model parallel size {} ...'.
format(model_parallel_size))
print('> testing broadcast_data with model parallel size {} ...'.
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
model_parallel_size = mpu.get_model_parallel_world_size()
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1],
......@@ -48,7 +48,7 @@ def test_boradcast_data(model_parallel_size):
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if mpu.get_model_parallel_rank() != 0:
if mpu.get_intra_layer_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
......@@ -69,7 +69,7 @@ def test_boradcast_data(model_parallel_size):
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
mpu.destroy_model_parallel()
mpu.destroy_intra_layer_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
......@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test test boradcast data')
test_boradcast_data(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test test broadcast data')
test_broadcast_data(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
......@@ -21,15 +21,15 @@ import sys
sys.path.append("../..")
def test_initialize_model_parallel(model_parallel_size):
def test_initialize_model_parallel(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
model_parallel_size))
model_parallel_size_ = min(model_parallel_size,
intra_layer_model_parallel_size))
intra_layer_model_parallel_size_ = min(intra_layer_model_parallel_size,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size_)
mpu.initialize_model_parallel(intra_layer_model_parallel_size_)
assert mpu.model_parallel_is_initialized()
# Checks.
......@@ -38,15 +38,15 @@ def test_initialize_model_parallel(model_parallel_size):
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = model_parallel_size_
rank = torch.distributed.get_rank() % model_parallel_size_
assert world_size == mpu.get_model_parallel_world_size()
assert rank == mpu.get_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank)
world_size = intra_layer_model_parallel_size_
rank = torch.distributed.get_rank() % intra_layer_model_parallel_size_
assert world_size == mpu.get_intra_layer_model_parallel_world_size()
assert rank == mpu.get_intra_layer_model_parallel_rank()
check(mpu.get_intra_layer_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size
world_size = torch.distributed.get_world_size() // intra_layer_model_parallel_size_
rank = torch.distributed.get_rank() // intra_layer_model_parallel_size
assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank)
......@@ -59,20 +59,20 @@ def test_initialize_model_parallel(model_parallel_size):
print('>> passed the test :-)')
def test_get_model_parallel_src_rank(model_parallel_size_):
def test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_model_parallel_src_rank with size {} ...'.format(
model_parallel_size_))
model_parallel_size = min(model_parallel_size_,
print('> testing get_intra_layer_model_parallel_src_rank with size {} ...'.format(
intra_layer_model_parallel_size_))
intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
assert mpu.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
assert mpu.get_model_parallel_src_rank() == src_rank
src_rank = torch.distributed.get_rank() - mpu.get_intra_layer_model_parallel_rank()
assert mpu.get_intra_layer_model_parallel_src_rank() == src_rank
# Reset groups
mpu.destroy_model_parallel()
......@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(model_parallel_size)
test_initialize_model_parallel(intra_layer_model_parallel_size)
print_separator('test model parallel source rank')
test_get_model_parallel_src_rank(model_parallel_size)
model_parallel_size *= 2
test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
......@@ -26,14 +26,14 @@ import sys
sys.path.append("../..")
def test_parallel_embedding(model_parallel_size):
def test_parallel_embedding(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size))
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
batch_size = 17
seq_length = 23
......@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size,
1)[mpu.get_model_parallel_rank()]
hidden_size // intra_layer_model_parallel_size,
1)[mpu.get_intra_layer_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size,
0)[mpu.get_model_parallel_rank()]
vocab_size // intra_layer_model_parallel_size,
0)[mpu.get_intra_layer_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
......@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size):
print('>> passed the test :-)')
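The gradient comparisons above slice the full, unsharded embedding weight the same way the parallel layers shard it: along the hidden dimension (dim 1) for the column-parallel embedding and along the vocabulary dimension (dim 0) for the vocab-parallel embedding. A self-contained sketch of that slicing with made-up sizes (illustrative only):

import torch

hidden_size, vocab_size = 16, 32
intra_layer_model_parallel_size, rank = 4, 1
full_weight = torch.randn(vocab_size, hidden_size)

# Column-parallel shard: each rank owns a slice of the hidden dimension.
hidden_shard = torch.split(full_weight,
                           hidden_size // intra_layer_model_parallel_size,
                           dim=1)[rank]
# Vocab-parallel shard: each rank owns a slice of the vocabulary dimension.
vocab_shard = torch.split(full_weight,
                          vocab_size // intra_layer_model_parallel_size,
                          dim=0)[rank]

print(hidden_shard.shape, vocab_shard.shape)  # torch.Size([32, 4]) torch.Size([8, 16])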
def test_initialize_affine_weight(model_parallel_size):
def test_initialize_affine_weight(intra_layer_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * intra_layer_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * intra_layer_model_parallel_size
# ---------------
# Column parallel
......@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
......@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
......@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight
def test_column_parallel_linear(model_parallel_size):
def test_column_parallel_linear(intra_layer_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * intra_layer_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * intra_layer_model_parallel_size
batch_size = 7
# Network
......@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size):
print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size):
def test_row_parallel_linear(intra_layer_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
'size: {}'.format(intra_layer_model_parallel_size))
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
input_size = input_size_coeff * intra_layer_model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
output_size = output_size_coeff * intra_layer_model_parallel_size
batch_size = 7
# Network
......@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
......@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
return rank, hidden_size, intra_layer_model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size):
def test_parallel_self_attention(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size))
'size: {}'.format(intra_layer_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
......@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size):
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
rank_1, hidden_size_1, intra_layer_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
rank, hidden_size, intra_layer_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention(
model_parallel_size, num_att_heads_per_partition,
intra_layer_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hidden_size_1 == hidden_size
......@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size):
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
hidden_size // intra_layer_model_parallel_size, 0)[rank::intra_layer_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
......@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size):
print(' >> passed the test :-)')
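The strided split above, [rank::intra_layer_model_parallel_size], picks this rank's query, key and value rows out of the fused weight of the single-partition reference layer. A minimal sketch of that reassembly with made-up sizes (illustrative only, not part of this diff):

import torch

hidden_size, intra_layer_model_parallel_size, rank = 12, 3, 1
# Fused QKV weight of the unpartitioned reference layer: 3 * hidden_size rows.
full_qkv_weight = torch.randn(3 * hidden_size, hidden_size)

chunks = torch.split(full_qkv_weight,
                     hidden_size // intra_layer_model_parallel_size, dim=0)
# Every intra_layer_model_parallel_size-th chunk, starting at this rank,
# yields the rank's Q, K and V rows in order.
my_rows = torch.cat(chunks[rank::intra_layer_model_parallel_size], dim=0)
print(my_rows.shape)  # torch.Size([12, 12])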
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
......@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
rank = mpu.get_intra_layer_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
return rank, hidden_size, intra_layer_model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size):
def test_parallel_transformer_layer(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size))
'size: {}'.format(intra_layer_model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
rank_1, hidden_size_1, intra_layer_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
rank, hidden_size, intra_layer_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition,
intra_layer_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
......@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_initialize_affine_weight(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size)
model_parallel_size *= 2
test_parallel_embedding(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
print_separator('test column-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_column_parallel_linear(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
print_separator('test row-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_row_parallel_linear(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
print_separator('test parallel self-attention')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_parallel_self_attention(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
print_separator('test parallel transformer')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size)
model_parallel_size *= 2
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
test_parallel_transformer_layer(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
......@@ -21,14 +21,14 @@ import sys
sys.path.append("../..")
def test_set_cuda_rng_state(model_parallel_size):
def test_set_cuda_rng_state(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(model_parallel_size))
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
size = 123
seed = 1234
......@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size):
print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size):
def test_cuda_rng_tracker(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size))
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
......@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size):
print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size):
def test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size))
format(intra_layer_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.initialize_model_parallel(intra_layer_model_parallel_size)
intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345)
mpu.intra_layer_model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank())
mpu.get_intra_layer_model_parallel_rank())
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
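The assertions above pin down the seeding rule: the default CUDA RNG keeps the user-provided seed, while the model-parallel RNG region is seeded with a fixed offset plus the intra-layer rank so that, for example, dropout patterns differ across partitions. A sketch of that rule (the 2718 offset is taken from the assertion above):

def intra_layer_model_parallel_seed(base_seed, intra_layer_rank, offset=2718):
    # Seed for the model-parallel RNG region; the default RNG keeps base_seed.
    return base_seed + offset + intra_layer_rank

assert intra_layer_model_parallel_seed(12345, 0) == 12345 + 2718
assert intra_layer_model_parallel_seed(12345, 3) == 12345 + 2718 + 3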
......@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size)
model_parallel_size *= 2
test_set_cuda_rng_state(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size)
model_parallel_size *= 2
test_cuda_rng_tracker(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
intra_layer_model_parallel_size = 1
while intra_layer_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size)
model_parallel_size *= 2
test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size)
intra_layer_model_parallel_size *= 2
......@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
......@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
terminate_runs = 0
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
if input_pos == input_count:
......@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
......@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
......@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
context_count += 1
......@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
model.eval()
with torch.no_grad():
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
terminate_runs = 0
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text:
......@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
......@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and \
if mpu.get_intra_layer_model_parallel_rank() == 0 and \
counter % print_frequency == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
......@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
......@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
context_count += 1
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
......@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
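Text generation repeats one pattern throughout: only the source rank of each intra-layer group reads input, and control flags plus context tokens are then broadcast to the rest of the group. A condensed sketch of that pattern (the helper name is made up; it assumes the usual `from megatron import mpu` import and an already-initialized distributed setup):

import torch
from megatron import mpu

def broadcast_from_intra_layer_src(tensor):
    # Share a tensor produced on the group's source rank with the other
    # intra-layer model-parallel ranks.
    torch.distributed.broadcast(tensor,
                                mpu.get_intra_layer_model_parallel_src_rank(),
                                group=mpu.get_intra_layer_model_parallel_group())
    return tensor

# e.g. broadcast_from_intra_layer_src(torch.cuda.LongTensor([terminate_runs]))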
......
......@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size
args.intra_layer_model_parallel_size
while (after % multiple) != 0:
after += 1
if args.rank == 0:
......
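The padding rule above rounds the vocabulary up to the next multiple of make_vocab_size_divisible_by times the intra-layer model-parallel size, so that the vocab-parallel embedding splits evenly across partitions. A stand-alone sketch of the same rule:

def padded_vocab_size(orig_vocab_size, make_vocab_size_divisible_by,
                      intra_layer_model_parallel_size):
    multiple = make_vocab_size_divisible_by * intra_layer_model_parallel_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

# GPT-2's 50257-token vocabulary padded for divisible-by-128 and intra-layer size 2.
assert padded_vocab_size(50257, 128, 2) == 50432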