Commit 9b174da8 authored by Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import torch import torch
from .initialize import get_model_parallel_group from .initialize import get_tensor_model_parallel_group
from .initialize import get_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
from .initialize import get_model_parallel_src_rank from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4 _MAX_DATA_DIM = 4
...@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data): ...@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes = [0 for _ in range(max_dim) for _ in keys] sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero. # Pack the sizes on rank zero.
if get_model_parallel_rank() == 0: if get_tensor_model_parallel_rank() == 0:
offset = 0 offset = 0
for key in keys: for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
...@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data): ...@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast. # Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes) sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
group=get_model_parallel_group()) group=get_tensor_model_parallel_group())
# Move back to cpu and unpack. # Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu() sizes_cpu = sizes_cuda.cpu()
...@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype): ...@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data) data)
# Pack on rank zero. # Pack on rank zero.
if get_model_parallel_rank() == 0: if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type. # Check that all keys have the same data type.
_check_data_types(keys, data, datatype) _check_data_types(keys, data, datatype)
# Flatten the data associated with the keys # Flatten the data associated with the keys
...@@ -100,9 +100,9 @@ def broadcast_data(keys, data, datatype): ...@@ -100,9 +100,9 @@ def broadcast_data(keys, data, datatype):
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=datatype) dtype=datatype)
# Boradcast # Broadcast
torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
group=get_model_parallel_group()) group=get_tensor_model_parallel_group())
# Unpack # Unpack
output = {} output = {}
......
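Usage sketch for the renamed data-broadcast helpers above: only tensor model-parallel rank 0 supplies real tensors to broadcast_data, while the other ranks in the group pass None and receive identical copies back. The import path and keys below are illustrative assumptions (they mirror the test file later in this diff), and the snippet presupposes an initialized torch.distributed job with mpu already set up.

import torch
from megatron import mpu  # assumed package layout; broadcast_data re-exported from .data

keys = ['text', 'labels']  # hypothetical keys
if mpu.get_tensor_model_parallel_rank() == 0:
    # Source rank builds the real batch.
    data = {'text': torch.randint(0, 1000, (4, 512)),
            'labels': torch.randint(0, 1000, (4, 512))}
else:
    # Non-source ranks only receive the broadcast.
    data = None
# Every rank in the tensor model-parallel group gets identical CUDA tensors back.
data_b = mpu.broadcast_data(keys, data, torch.int64)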
...@@ -28,8 +28,9 @@ try: ...@@ -28,8 +28,9 @@ try:
except Exception as e: except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.') print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
def l2_grad_clipper(parameters, max_norm): def l2_grad_clipper(parameters, max_norm):
...@@ -43,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm): ...@@ -43,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads = list(filter( parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters)) lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations. # Filter parameters for norm calculations.
mp_rank_is_zero = (get_model_parallel_rank() == 0) mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter( parameters_for_norm = list(filter(
lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads)) lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm. # Calculate L2 norm.
norm, _ = multi_tensor_applier( norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, amp_C.multi_tensor_l2norm,
...@@ -71,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm): ...@@ -71,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
return total_norm return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2): def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters. """Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
...@@ -90,13 +91,27 @@ def clip_grad_norm(parameters, max_norm, norm_type=2): ...@@ -90,13 +91,27 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
""" """
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters)) if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm) max_norm = float(max_norm)
norm_type = float(norm_type) norm_type = float(norm_type)
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters) total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs. # Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX, op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group()) group=get_model_parallel_group())
...@@ -105,16 +120,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2): ...@@ -105,16 +120,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
if clip_coef < 1: if clip_coef < 1:
for p in parameters: for p in parameters:
p.grad.data.mul_(clip_coef) p.grad.data.mul_(clip_coef)
#elif norm_type == 2:
# total_norm = l2_grad_clipper(parameters, max_norm)
else: else:
total_norm = 0 total_norm = 0
for p in parameters: for p in parameters:
if p.model_parallel or (get_model_parallel_rank() == 0): if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type) param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
total_norm += param_norm.item() ** norm_type total_norm += param_norm.item() ** norm_type
# Sum across all model parallel GPUs. # Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
......
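The 2-norm branch above accumulates each rank's local contribution as param_norm ** norm_type and sums the result across the model-parallel group before taking the root, so every rank ends up with the same clip coefficient. A minimal single-process sketch of that arithmetic (the all-reduce is omitted because there is only one rank here; the epsilon follows torch.nn.utils.clip_grad_norm_, which the docstring says this function is adapted from):

import torch

def clip_grad_norm_sketch(grads, max_norm, norm_type=2.0):
    # Local contribution; in the distributed version this sum is all-reduced
    # (SUM) over the model-parallel group before the root is taken.
    total_norm = sum(torch.linalg.norm(g.flatten(), norm_type).item() ** norm_type
                     for g in grads)
    total_norm = total_norm ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm

grads = [torch.randn(10), torch.randn(3, 4)]
clip_grad_norm_sketch(grads, max_norm=1.0)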
...@@ -21,75 +21,148 @@ import torch ...@@ -21,75 +21,148 @@ import torch
from .utils import ensure_divisibility from .utils import ensure_divisibility
# Model parallel group that the current rank belongs to. # Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None _MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to. # Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly. # These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized(): def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization""" """Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(model_parallel_size_): def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1):
""" """
Initialize model data parallel groups. Initialize model data parallel groups.
Arguments: Arguments:
model_parallel_size: number of GPUs used to parallelize model. tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
create 4 model parallel groups and 2 data parallel groups as: create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
4 model parallel groups: the model pipeline. The present function will
[g0, g1], [g2, g3], [g4, g5], [g6, g7] create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
2 data parallel groups: and 8 data-parallel groups as:
[g0, g2, g4, g6], [g1, g3, g5, g7] 8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box. ranks 8 to 15 belong to the second box.
""" """
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> initializing model parallel with size {}'.format( print('> initializing tensor model parallel with size {}'.format(
model_parallel_size_)) tensor_model_parallel_size_))
print('> initializing pipeline model parallel with size {}'.format(
pipeline_model_parallel_size_))
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size) tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size) pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size,
tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the data parallel groups. # Build the data-parallel groups.
global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \ assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized' 'data parallel group is already initialized'
for i in range(model_parallel_size): all_data_parallel_group_ranks = []
ranks = range(i, world_size, model_parallel_size) for i in range(pipeline_model_parallel_size):
group = torch.distributed.new_group(ranks) start_rank = i * num_pipeline_model_parallel_groups
if i == (rank % model_parallel_size): end_rank = (i + 1) * num_pipeline_model_parallel_groups
_DATA_PARALLEL_GROUP = group for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank,
# Build the model parallel groups. tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \ assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized' 'model parallel group is already initialized'
for i in range(world_size // model_parallel_size): for i in range(data_parallel_size):
ranks = range(i * model_parallel_size, ranks = [data_parallel_group_ranks[i]
(i + 1) * model_parallel_size) for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size): if rank in ranks:
_MODEL_PARALLEL_GROUP = group _MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size,
num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
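As a sanity check on the grouping logic above, the following standalone sketch (pure Python, no process groups created, helper name made up for illustration) enumerates the data-parallel, tensor model-parallel and pipeline model-parallel rank groups and reproduces the 16-GPU example from the docstring:

def enumerate_parallel_groups(world_size, tensor_mp_size, pipeline_mp_size):
    # Mirrors the loop structure above without calling torch.distributed.
    assert world_size % (tensor_mp_size * pipeline_mp_size) == 0
    num_tensor_groups = world_size // tensor_mp_size
    num_pipeline_groups = world_size // pipeline_mp_size

    data_parallel_groups = []
    for i in range(pipeline_mp_size):
        start = i * num_pipeline_groups
        end = (i + 1) * num_pipeline_groups
        for j in range(tensor_mp_size):
            data_parallel_groups.append(list(range(start + j, end, tensor_mp_size)))

    tensor_groups = [list(range(i * tensor_mp_size, (i + 1) * tensor_mp_size))
                     for i in range(num_tensor_groups)]
    pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                       for i in range(num_pipeline_groups)]
    return data_parallel_groups, tensor_groups, pipeline_groups

# With 16 GPUs, tensor size 2 and pipeline size 4 this yields the groups listed
# in the docstring, e.g. data-parallel [0, 2], tensor [0, 1], pipeline [0, 4, 8, 12].
dp, tp, pp = enumerate_parallel_groups(16, 2, 4)
assert dp[0] == [0, 2] and tp[0] == [0, 1] and pp[0] == [0, 4, 8, 12]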
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False return False
return True return True
...@@ -101,6 +174,20 @@ def get_model_parallel_group(): ...@@ -101,6 +174,20 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP return _MODEL_PARALLEL_GROUP
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
'pipeline_model parallel group is not initialized'
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group(): def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to.""" """Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \ assert _DATA_PARALLEL_GROUP is not None, \
...@@ -108,41 +195,97 @@ def get_data_parallel_group(): ...@@ -108,41 +195,97 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP return _DATA_PARALLEL_GROUP
def set_model_parallel_world_size(world_size): def get_embedding_group():
"""Set the model parallel size""" """Get the embedding group the caller rank belongs to."""
global _MPU_WORLD_SIZE assert _EMBEDDING_GROUP is not None, \
_MPU_WORLD_SIZE = world_size 'embedding group is not initialized'
return _EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
global _MPU_WORLD_SIZE
if _MPU_WORLD_SIZE is not None:
return _MPU_WORLD_SIZE
return torch.distributed.get_world_size(group=get_model_parallel_group())
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def set_model_parallel_rank(rank):
"""Set model parallel rank."""
global _MPU_RANK
_MPU_RANK = rank
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
global _MPU_RANK
if _MPU_RANK is not None:
return _MPU_RANK
return torch.distributed.get_rank(group=get_model_parallel_group())
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zeor def set_pipeline_model_parallel_rank(rank):
in the model parallel group.""" """Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage():
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage():
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank() global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size() local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size return (global_rank // local_world_size) * local_world_size
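The source rank is simply the first global rank of the caller's tensor model-parallel group, obtained by integer arithmetic; for example, with a tensor model-parallel world size of 2, global ranks 6 and 7 both map to source rank 6. A standalone illustration of the same formula (function name invented for the sketch):

def tensor_model_parallel_src_rank(global_rank, tensor_model_parallel_world_size):
    # First rank in the tensor model-parallel group containing global_rank.
    return ((global_rank // tensor_model_parallel_world_size)
            * tensor_model_parallel_world_size)

assert tensor_model_parallel_src_rank(6, 2) == 6
assert tensor_model_parallel_src_rank(7, 2) == 6
assert tensor_model_parallel_src_rank(5, 4) == 4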
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_data_parallel_world_size(): def get_data_parallel_world_size():
"""Return world size for the data parallel group.""" """Return world size for the data parallel group."""
...@@ -156,7 +299,9 @@ def get_data_parallel_rank(): ...@@ -156,7 +299,9 @@ def get_data_parallel_rank():
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
global _MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP = None
...@@ -35,12 +35,12 @@ except Exception as e: ...@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!') 'instead of apex.normalization.FusedLayerNorm!')
from torch.nn import LayerNorm from torch.nn import LayerNorm
from .initialize import get_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
from .initialize import get_model_parallel_world_size from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_model_parallel_region from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
from .utils import divide from .utils import divide
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
...@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method, ...@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1): partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU.""" """Initialize affine weight for model parallel on GPU."""
weight.model_parallel = True weight.tensor_model_parallel = True
weight.partition_dim = partition_dim weight.partition_dim = partition_dim
weight.partition_stride = stride weight.partition_stride = stride
...@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter Build the master weight on all processes and scatter
the relevant chunk.""" the relevant chunk."""
weight.model_parallel = True weight.tensor_model_parallel = True
weight.partition_dim = partition_dim weight.partition_dim = partition_dim
weight.partition_stride = stride weight.partition_stride = stride
...@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
with torch.no_grad(): with torch.no_grad():
...@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False self.scale_grad_by_freq = False
self.sparse = False self.sparse = False
self._weight = None self._weight = None
self.model_parallel_size = get_model_parallel_world_size() self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension. # Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \ self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size( VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(), self.num_embeddings, get_tensor_model_parallel_rank(),
self.model_parallel_size) self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \ self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index self.vocab_start_index
...@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim=0, stride=1) partition_dim=0, stride=1)
def forward(self, input_): def forward(self, input_):
if self.model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
# Build the mask. # Build the mask.
input_mask = (input_ < self.vocab_start_index) | \ input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index) (input_ >= self.vocab_end_index)
...@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.norm_type, self.scale_grad_by_freq, self.norm_type, self.scale_grad_by_freq,
self.sparse) self.sparse)
# Mask the output embedding. # Mask the output embedding.
if self.model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0 output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs. # Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel) output = reduce_from_tensor_model_parallel_region(output_parallel)
return output return output
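The forward pass above works because every rank owns a contiguous slice of the vocabulary: out-of-range token ids are shifted into the local range, looked up, zeroed out again via the mask, and the all-reduce then sums the per-rank partial embeddings into the full result. A single-process numerical sketch of that scheme (plain torch; the running sum stands in for reduce_from_tensor_model_parallel_region):

import torch

vocab, hidden, tp = 8, 4, 2
weight = torch.randn(vocab, hidden)          # full embedding table
tokens = torch.tensor([0, 3, 5, 7])
reference = weight[tokens]

per_partition = vocab // tp
summed = torch.zeros(len(tokens), hidden)
for rank in range(tp):
    lo, hi = rank * per_partition, (rank + 1) * per_partition
    shard = weight[lo:hi]                    # this rank's vocabulary slice
    mask = (tokens < lo) | (tokens >= hi)    # tokens owned by other ranks
    local_ids = tokens - lo
    local_ids[mask] = 0                      # any in-range index; rows are zeroed below
    out = shard[local_ids]
    out[mask, :] = 0.0
    summed += out                            # stands in for the all-reduce

assert torch.equal(summed, reference)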
...@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size self.output_size = output_size
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size) self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
...@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.model_parallel = True self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0 self.bias.partition_dim = 0
self.bias.stride = stride self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
...@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias) output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
else: else:
output = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
...@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size self.output_size = output_size
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size) self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
...@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module): ...@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
input_parallel = scatter_to_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight) output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions. # All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add: if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_ output = output_ + self.bias if self.bias is not None else output_
output_bias = None output_bias = None
......
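ColumnParallelLinear splits its weight along the output dimension (identity in the forward pass, all-reduce in the backward pass via copy_to_tensor_model_parallel_region), while RowParallelLinear splits along the input dimension and all-reduces its partial outputs in the forward pass, so a column-parallel layer feeding a row-parallel layer needs no communication in between. A single-process numerical check of that pairing (plain matmuls; the final sum plays the role of the forward all-reduce, and all shapes are made up):

import torch

torch.manual_seed(0)
batch, d_in, d_hidden, d_out, tp = 4, 8, 16, 8, 2

x = torch.randn(batch, d_in)
W1 = torch.randn(d_hidden, d_in)   # column-parallel weight: split along rows (output dim)
W2 = torch.randn(d_out, d_hidden)  # row-parallel weight: split along columns (input dim)

dense = x @ W1.t() @ W2.t()        # unpartitioned reference

partial_sum = torch.zeros(batch, d_out)
shard = d_hidden // tp
for rank in range(tp):
    w1_shard = W1[rank * shard:(rank + 1) * shard]       # this rank's output slice
    w2_shard = W2[:, rank * shard:(rank + 1) * shard]    # matching input slice
    partial_sum += (x @ w1_shard.t()) @ w2_shard.t()     # forward all-reduce as a sum

assert torch.allclose(dense, partial_sum, atol=1e-4)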
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import torch import torch
from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
...@@ -23,11 +23,11 @@ def _reduce(input_): ...@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group.""" """All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if get_model_parallel_world_size()==1: if get_tensor_model_parallel_world_size()==1:
return input_ return input_
# All-reduce. # All-reduce.
torch.distributed.all_reduce(input_, group=get_model_parallel_group()) torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_ return input_
...@@ -36,7 +36,7 @@ def _split(input_): ...@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the """Split the tensor along its last dimension and keep the
corresponding slice.""" corresponding slice."""
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size==1: if world_size==1:
return input_ return input_
...@@ -45,7 +45,7 @@ def _split(input_): ...@@ -45,7 +45,7 @@ def _split(input_):
input_list = split_tensor_along_last_dim(input_, world_size) input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default. # Note: torch.split does not create contiguous tensors by default.
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous() output = input_list[rank].contiguous()
return output return output
...@@ -54,18 +54,18 @@ def _split(input_): ...@@ -54,18 +54,18 @@ def _split(input_):
def _gather(input_): def _gather(input_):
"""Gather tensors and concatinate along the last dimension.""" """Gather tensors and concatinate along the last dimension."""
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size==1: if world_size==1:
return input_ return input_
# Size and dimension. # Size and dimension.
last_dim = input_.dim() - 1 last_dim = input_.dim() - 1
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_ tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group()) torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor. # Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous() output = torch.cat(tensor_list, dim=last_dim).contiguous()
...@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function): ...@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions. # Helper functions.
# ----------------- # -----------------
def copy_to_model_parallel_region(input_): def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_) return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_): def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_) return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_): def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_) return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_): def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_) return _GatherFromModelParallelRegion.apply(input_)
...@@ -28,9 +28,9 @@ from megatron import get_args ...@@ -28,9 +28,9 @@ from megatron import get_args
from megatron.memory import allocate_mem_buff from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_group from .initialize import get_tensor_model_parallel_group
from .initialize import get_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
from .initialize import get_model_parallel_world_size from .initialize import get_tensor_model_parallel_world_size
# Default name for the model parallel rng tracker. # Default name for the model parallel rng tracker.
...@@ -45,8 +45,8 @@ def init_checkpointed_activations_memory_buffer(): ...@@ -45,8 +45,8 @@ def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations.""" """Initializ the memory buffer for the checkpointed activations."""
args = get_args() args = get_args()
per_layer = args.batch_size * args.max_position_embeddings * \ per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.model_parallel_size args.hidden_size // args.tensor_model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \ assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers' 'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
...@@ -54,7 +54,7 @@ def init_checkpointed_activations_memory_buffer(): ...@@ -54,7 +54,7 @@ def init_checkpointed_activations_memory_buffer():
dtype = torch.half dtype = torch.half
if not args.fp16: if not args.fp16:
dtype = torch.float dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \ assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
'checkpointed activations memory buffer is already allocated.' 'checkpointed activations memory buffer is already allocated.'
...@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1): ...@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def split_tensor_into_1d_equal_chunks(tensor): def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks.""" """Break a tensor into equal 1D chunks."""
data = tensor.view(-1) data = tensor.view(-1)
partition_size = torch.numel(data) // get_model_parallel_world_size() partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
start_index = partition_size * get_model_parallel_rank() start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size end_index = start_index + partition_size
return data[start_index:end_index] return data[start_index:end_index]
def gather_split_1d_tensor(tensor): def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks.""" """Opposite of above function, gather values from model parallel ranks."""
world_size = get_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor) numel = torch.numel(tensor)
numel_gathered = world_size * numel numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
...@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor): ...@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad=False) requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)] chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, torch.distributed.all_gather(chunks, tensor,
group=get_model_parallel_group()) group=get_tensor_model_parallel_group())
return gathered return gathered
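split_tensor_into_1d_equal_chunks keeps only this rank's contiguous slice of the flattened tensor, and gather_split_1d_tensor reassembles the full tensor with an all-gather over the tensor model-parallel group. A single-process round-trip sketch (list indexing and concatenation stand in for the ranks and the all-gather):

import torch

def split_1d(tensor, rank, world_size):
    # Mirrors split_tensor_into_1d_equal_chunks for a given rank.
    data = tensor.view(-1)
    partition_size = data.numel() // world_size
    start = partition_size * rank
    return data[start:start + partition_size]

world_size = 4
t = torch.arange(16.0)
chunks = [split_1d(t, r, world_size) for r in range(world_size)]
# gather_split_1d_tensor would all-gather these chunks; concatenating them
# locally recovers the original flattened tensor.
assert torch.equal(torch.cat(chunks), t)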
...@@ -215,15 +215,15 @@ def model_parallel_cuda_manual_seed(seed): ...@@ -215,15 +215,15 @@ def model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across set of model parallel GPUs but different across
different model parallel groups. This is used for different model parallel groups. This is used for
example for dropout in the non-model-parallel regions. example for dropout in the non-tensor-model-parallel regions.
model-parallel state: This state is different among a set of model tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in groups. This is used for example for dropout in
model parallel regions. model parallel regions.
""" """
# 2718 is just for fun and any POSITIVE value will work. # 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718 offset = seed + 2718
model_parallel_seed = offset + get_model_parallel_rank() tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed. # Data parallel gets the original seed.
data_parallel_seed = seed data_parallel_seed = seed
...@@ -231,15 +231,15 @@ def model_parallel_cuda_manual_seed(seed): ...@@ -231,15 +231,15 @@ def model_parallel_cuda_manual_seed(seed):
print('> initializing model parallel cuda seeds on global rank {}, ' print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with ' 'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format( 'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_model_parallel_rank(), torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
get_data_parallel_rank(), model_parallel_seed, get_data_parallel_rank(), tensor_model_parallel_seed,
data_parallel_seed), flush=True) data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset() _CUDA_RNG_STATE_TRACKER.reset()
# Set the default state. # Set the default state.
torch.cuda.manual_seed(data_parallel_seed) torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state. # and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
model_parallel_seed) tensor_model_parallel_seed)
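With these offsets every tensor model-parallel rank receives a distinct seed for the model-parallel RNG tracker, while the default CUDA state is seeded with the value passed in; for seed = 1234 the arithmetic works out as follows (worked example only):

seed = 1234
offset = seed + 2718
for tensor_model_parallel_rank in range(4):
    tensor_model_parallel_seed = offset + tensor_model_parallel_rank  # 3952, 3953, 3954, 3955
data_parallel_seed = seed  # 1234, used for torch.cuda.manual_seed on every rank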
class CheckpointFunction(torch.autograd.Function): class CheckpointFunction(torch.autograd.Function):
...@@ -268,11 +268,11 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -268,11 +268,11 @@ class CheckpointFunction(torch.autograd.Function):
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data) args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add( args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
args[0].data) args[0].data)
# Store everything. # Store everything.
ctx.save_for_backward(*args) ctx.save_for_backward(*args)
return outputs return outputs
@staticmethod @staticmethod
......
...@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, ...@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity = IdentityLayer((batch_size, seq_length, vocab_size), identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda() scale=logits_scale).cuda()
logits = identity() logits = identity()
logits_parallel = mpu.scatter_to_model_parallel_region(logits) logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
target = torch.cuda.LongTensor( target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size) size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
...@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, ...@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return loss, identity.weight.grad return loss, identity.weight.grad
def test_cross_entropy(model_parallel_size): def test_cross_entropy(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'. print('> testing cross entropy with model parallel size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 13 batch_size = 13
seq_length = 17 seq_length = 17
vocab_size_per_partition = 11 vocab_size_per_partition = 11
logits_scale = 1000.0 logits_scale = 1000.0
vocab_size = vocab_size_per_partition * model_parallel_size vocab_size = vocab_size_per_partition * tensor_model_parallel_size
seed = 1234 seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
...@@ -89,7 +89,7 @@ def test_cross_entropy(model_parallel_size): ...@@ -89,7 +89,7 @@ def test_cross_entropy(model_parallel_size):
assert error < 1.0e-6 assert error < 1.0e-6
# Reset groups # Reset groups
mpu.destroy_model_parallel() mpu.destroy_tensor_model_parallel()
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -101,8 +101,8 @@ if __name__ == '__main__': ...@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test cross entropy') print_separator('test cross entropy')
test_cross_entropy(model_parallel_size) test_cross_entropy(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -24,15 +24,15 @@ import sys ...@@ -24,15 +24,15 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_boradcast_data(model_parallel_size): def test_broadcast_data(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing boradcast_data with model parallel size {} ...'. print('> testing broadcast_data with model parallel size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank()) torch.manual_seed(1234 + mpu.get_data_parallel_rank())
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
key_size_t = {'key1': [7, 11], key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1], 'key2': [8, 2, 1],
...@@ -48,7 +48,7 @@ def test_boradcast_data(model_parallel_size): ...@@ -48,7 +48,7 @@ def test_boradcast_data(model_parallel_size):
data_t[key] = data[key].clone() data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone() data_t['keyX'] = data['keyX'].clone()
if mpu.get_model_parallel_rank() != 0: if mpu.get_tensor_model_parallel_rank() != 0:
data = None data = None
data_utils._check_data_types(keys, data_t, torch.int64) data_utils._check_data_types(keys, data_t, torch.int64)
...@@ -69,7 +69,7 @@ def test_boradcast_data(model_parallel_size): ...@@ -69,7 +69,7 @@ def test_boradcast_data(model_parallel_size):
assert data_b[key].sub(tensor).abs().max() == 0 assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups # Reset groups
mpu.destroy_model_parallel() mpu.destroy_tensor_model_parallel()
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -81,8 +81,8 @@ if __name__ == '__main__': ...@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test test boradcast data') print_separator('test test broadcast data')
test_boradcast_data(model_parallel_size) test_broadcast_data(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -21,15 +21,15 @@ import sys ...@@ -21,15 +21,15 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_initialize_model_parallel(model_parallel_size): def test_initialize_model_parallel(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format( print('> testing initialize_model_parallel with size {} ...'.format(
model_parallel_size)) tensor_model_parallel_size))
model_parallel_size_ = min(model_parallel_size, tensor_model_parallel_size_ = min(tensor_model_parallel_size,
torch.distributed.get_world_size()) torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized() assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size_) mpu.initialize_model_parallel(tensor_model_parallel_size_)
assert mpu.model_parallel_is_initialized() assert mpu.model_parallel_is_initialized()
# Checks. # Checks.
...@@ -38,15 +38,15 @@ def test_initialize_model_parallel(model_parallel_size): ...@@ -38,15 +38,15 @@ def test_initialize_model_parallel(model_parallel_size):
assert rank == torch.distributed.get_rank(group=group) assert rank == torch.distributed.get_rank(group=group)
# Model parallel. # Model parallel.
world_size = model_parallel_size_ world_size = tensor_model_parallel_size_
rank = torch.distributed.get_rank() % model_parallel_size_ rank = torch.distributed.get_rank() % tensor_model_parallel_size_
assert world_size == mpu.get_model_parallel_world_size() assert world_size == mpu.get_tensor_model_parallel_world_size()
assert rank == mpu.get_model_parallel_rank() assert rank == mpu.get_tensor_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank) check(mpu.get_tensor_model_parallel_group(), world_size, rank)
# Data parallel. # Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_ world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size rank = torch.distributed.get_rank() // tensor_model_parallel_size
assert world_size == mpu.get_data_parallel_world_size() assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank() assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank) check(mpu.get_data_parallel_group(), world_size, rank)
...@@ -59,20 +59,20 @@ def test_initialize_model_parallel(model_parallel_size): ...@@ -59,20 +59,20 @@ def test_initialize_model_parallel(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_get_model_parallel_src_rank(model_parallel_size_): def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing get_model_parallel_src_rank with size {} ...'.format( print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
model_parallel_size_)) tensor_model_parallel_size_))
model_parallel_size = min(model_parallel_size_, tensor_model_parallel_size = min(tensor_model_parallel_size_,
torch.distributed.get_world_size()) torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized() assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
assert mpu.model_parallel_is_initialized() assert mpu.model_parallel_is_initialized()
# Checks # Checks
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank() src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
assert mpu.get_model_parallel_src_rank() == src_rank assert mpu.get_tensor_model_parallel_src_rank() == src_rank
# Reset groups # Reset groups
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
...@@ -86,10 +86,10 @@ if __name__ == '__main__': ...@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test initialize model parallel') print_separator('test initialize model parallel')
test_initialize_model_parallel(model_parallel_size) test_initialize_model_parallel(tensor_model_parallel_size)
print_separator('test model parallel source rank') print_separator('test model parallel source rank')
test_get_model_parallel_src_rank(model_parallel_size) test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -26,14 +26,14 @@ import sys ...@@ -26,14 +26,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_parallel_embedding(model_parallel_size): def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'. print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 17 batch_size = 17
seq_length = 23 seq_length = 23
...@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size): ...@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size, hidden_size // tensor_model_parallel_size,
1)[mpu.get_model_parallel_rank()] 1)[mpu.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format( print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error)) torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size, vocab_size // tensor_model_parallel_size,
0)[mpu.get_model_parallel_rank()] 0)[mpu.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub( error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max() weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format( print(' error in grad (vocab parallel) on global rank {}: {}'.format(
...@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size): ...@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size): def test_initialize_affine_weight(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel ' print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
# --------------- # ---------------
# Column parallel # Column parallel
...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size): ...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff, my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size): ...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff, my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module): ...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight return self.weight
def test_column_parallel_linear(model_parallel_size): def test_column_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel ' print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size): ...@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff, my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
...@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size): ...@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size): def test_row_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel ' print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size): ...@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff, my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module): ...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length): sequence_length):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, ...@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size): def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel ' print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
...@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size):
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention( attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention( attention_layer, identity_layer = parallel_self_attention(
model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hidden_size_1 == hidden_size assert hidden_size_1 == hidden_size
...@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size):
my_lin_grad_list = torch.split( my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad, attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size] hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0) my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub( error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max() attention_layer.query_key_value.weight.grad).abs().max()
...@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition, def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length): hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition, ...@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size): def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel ' print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer( transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer( transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max() error = loss_1.sub(loss).abs().max()
...@@ -494,37 +494,37 @@ if __name__ == '__main__': ...@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight') print_separator('test initialize affine weight')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size) test_initialize_affine_weight(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test parallel embedding') print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size) test_parallel_embedding(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test column-parallel linear') print_separator('test column-parallel linear')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size) test_column_parallel_linear(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test row-parallel linear') print_separator('test row-parallel linear')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size) test_row_parallel_linear(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel self-attention') print_separator('test parallel self-attention')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size) test_parallel_self_attention(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel transformer') print_separator('test parallel transformer')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size) test_parallel_transformer_layer(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
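test_initialize_affine_weight above checks each rank's shard of a column- or row-parallel weight against a replicated master weight. A minimal standalone sketch of the column-parallel check, with the rank loop simulated locally instead of read from mpu and the sizes made up:

import torch

tensor_model_parallel_size = 4        # hypothetical
output_size_coeff, input_size_coeff = 17, 13
output_size = output_size_coeff * tensor_model_parallel_size
input_size = input_size_coeff * tensor_model_parallel_size

torch.manual_seed(12345)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)

# Column-parallel split: each rank owns a contiguous block of output rows.
for rank in range(tensor_model_parallel_size):
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()
    assert torch.equal(
        my_weight,
        master_weight[rank * output_size_coeff:(rank + 1) * output_size_coeff])
print('column-parallel shards match the master weight')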
...@@ -21,14 +21,14 @@ import sys ...@@ -21,14 +21,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_set_cuda_rng_state(model_parallel_size): def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'. print('> testing set_rng_state with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
size = 123 size = 123
seed = 1234 seed = 1234
...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size): ...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size): def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'. print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed_1 = 1234 seed_1 = 1234
seed_2 = 4321 seed_2 = 4321
...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size): ...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size): def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'. print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345) mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345 assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 + assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank()) mpu.get_tensor_model_parallel_rank())
# Reset the tracker # Reset the tracker
mpu.get_cuda_rng_tracker().reset() mpu.get_cuda_rng_tracker().reset()
...@@ -185,20 +185,20 @@ if __name__ == '__main__': ...@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test set rng state') print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size) test_set_cuda_rng_state(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker') print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size) test_cuda_rng_tracker(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed') print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size) test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
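test_model_parallel_cuda_manual_seed above expects the forked (model-parallel) RNG to be seeded at 12345 + 2718 + the tensor-model-parallel rank, while the default RNG keeps the plain seed. A CPU-generator sketch of that seed layout (the rank value is made up; the real CUDA tracker also saves and restores device RNG state around the fork):

import torch

seed = 12345
tensor_model_parallel_rank = 3        # hypothetical rank

# Default generator: identical on every rank, so dropout applied to
# replicated activations stays in sync across the tensor-parallel group.
default_gen = torch.Generator().manual_seed(seed)

# Model-parallel generator: per-rank offset (mirrors 12345 + 2718 + rank),
# so dropout applied to partitioned activations differs across ranks.
model_parallel_gen = torch.Generator().manual_seed(
    seed + 2718 + tensor_model_parallel_rank)

print(torch.rand(2, generator=default_gen))
print(torch.rand(2, generator=model_parallel_gen))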
...@@ -26,6 +26,7 @@ import torch.nn.functional as F ...@@ -26,6 +26,7 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.training import communicate
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
...@@ -35,7 +36,7 @@ def get_batch(context_tokens): ...@@ -35,7 +36,7 @@ def get_batch(context_tokens):
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
# Move to GPU. # Move to GPU.
tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda() tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids. # Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids( attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens, tokens,
...@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model): ...@@ -88,14 +89,14 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file. # Read the sample file and open the output file.
assert args.sample_input_file is not None, \ assert args.sample_input_file is not None, \
'sample input file is not provided.' 'sample input file is not provided.'
if mpu.get_model_parallel_rank() == 0: if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r") fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines() all_raw_text = fname.readlines()
input_count = len(all_raw_text) input_count = len(all_raw_text)
input_pos = 0 input_pos = 0
if args.sample_output_file is None: if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out" sample_output_file = args.sample_input_file + ".out"
print('could not find `sample-output-file`, setting ' print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file)) 'it to {}'.format(sample_output_file))
else: else:
sample_output_file = args.sample_output_file sample_output_file = args.sample_output_file
...@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model): ...@@ -105,14 +106,16 @@ def generate_samples_input_from_file(model):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs = 0 terminate_runs = 0
raw_text_len = 0
if mpu.get_model_parallel_rank() == 0: if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos] raw_text = all_raw_text[input_pos]
input_pos += 1 input_pos += 1
if input_pos == input_count: if input_pos == input_count:
raw_text = "stop" raw_text = "stop"
raw_text_len = len(raw_text)
if "stop" in raw_text: if "stop" in raw_text:
terminate_runs = 1 terminate_runs = 1
...@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model): ...@@ -127,38 +130,60 @@ def generate_samples_input_from_file(model):
continue continue
else: else:
context_tokens = tokenizer.tokenize("EMPTY TEXT") context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens) context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) input_info = [terminate_runs, raw_text_len, context_length]
torch.distributed.broadcast(terminate_runs_tensor, input_info_tensor = torch.cuda.LongTensor(input_info)
mpu.get_model_parallel_src_rank(), torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group()) group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1: if terminate_runs == 1:
return return
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens]) token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream): for _, decode_tokens in enumerate(token_stream):
decode_tokens, _ = decode_tokens pass
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear') if mpu.is_pipeline_first_stage():
print("\nContext:", raw_text, flush=True) os.system('clear')
trim_decode_tokens = tokenizer.detokenize( print("\nContext:", raw_text, flush=True)
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:") fname_out.write("\nContext:")
fname_out.write(raw_text) fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
raw_text = None decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
torch.distributed.barrier(group=mpu.get_model_parallel_group()) raw_text = None
context_count += 1 context_count += 1
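The change above replaces the old terminate_runs broadcast with a single all-reduce of [terminate_runs, raw_text_len, context_length]: only the first pipeline stage's tensor-parallel rank 0 fills in real values and every other rank contributes zeros, so the SUM is equivalent to a broadcast over the whole model-parallel group. A reduced sketch of that pattern (torch.distributed must already be initialized; the group argument stands in for mpu.get_model_parallel_group()):

import torch
import torch.distributed as dist

def share_control_scalars(values, is_source, group=None):
    # All-reduce a short list of Python ints; non-source ranks pass zeros,
    # so SUM acts as a broadcast from the single source rank.
    if not is_source:
        values = [0] * len(values)
    info = torch.cuda.LongTensor(values)
    dist.all_reduce(info, group=group)
    return [v.item() for v in info]

# Hypothetical use on the source rank:
#   terminate_runs, raw_text_len, context_length = \
#       share_control_scalars([0, len(raw_text), 0], is_source=True)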
...@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -171,15 +196,17 @@ def generate_samples_interactive(model, print_frequency=24):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
while True: while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs = 0 terminate_runs = 0
raw_text_len = 0
if mpu.get_model_parallel_rank() == 0: if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
raw_text = input("\nContext prompt (stop to exit) >>> ") raw_text = input("\nContext prompt (stop to exit) >>> ")
while not raw_text: while not raw_text:
print('Prompt should not be empty!') print('Prompt should not be empty!')
raw_text = input("\nContext prompt (stop to exit) >>> ") raw_text = input("\nContext prompt (stop to exit) >>> ")
raw_text_len = len(raw_text)
if "stop" in raw_text: if "stop" in raw_text:
terminate_runs = 1 terminate_runs = 1
...@@ -194,43 +221,71 @@ def generate_samples_interactive(model, print_frequency=24): ...@@ -194,43 +221,71 @@ def generate_samples_interactive(model, print_frequency=24):
continue continue
else: else:
context_tokens = tokenizer.tokenize("EMPTY TEXT") context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens) context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) input_info = [terminate_runs, raw_text_len, context_length]
torch.distributed.broadcast(terminate_runs_tensor, input_info_tensor = torch.cuda.LongTensor(input_info)
mpu.get_model_parallel_src_rank(), torch.distributed.all_reduce(input_info_tensor,
group=mpu.get_model_parallel_group()) group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item() terminate_runs = input_info_tensor[0].item()
raw_text_len = input_info_tensor[1].item()
context_length = input_info_tensor[2].item()
if terminate_runs == 1: if terminate_runs == 1:
return return
# For pipeline parallel we send context tokens to other stages
# so they get the lengths correct
if mpu.get_tensor_model_parallel_rank() == 0 \
and args.pipeline_model_parallel_size > 1:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
torch.distributed.broadcast(context_tokens_tensor, src, group)
else:
src = mpu.get_pipeline_model_parallel_first_rank()
group = mpu.get_pipeline_model_parallel_group()
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device("cuda"))
torch.distributed.broadcast(context_tokens_tensor, src, group)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
token_stream = get_token_stream(model, [context_tokens]) token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream): for counter, decode_tokens in enumerate(token_stream):
if counter % print_frequency != 0 \
or mpu.get_tensor_model_parallel_rank() != 0 \
or not mpu.is_pipeline_first_stage():
continue
os.system('clear')
print("\nContext:", raw_text, flush=True)
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0 and \ if mpu.is_pipeline_first_stage() \
counter % print_frequency == 0: and mpu.get_tensor_model_parallel_rank() == 0:
os.system('clear')
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0:
os.system('clear') os.system('clear')
print("\nContext:", raw_text, flush=True) print("\nContext:", raw_text, flush=True)
if not isinstance(decode_tokens, list):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize( trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):] decode_tokens)[raw_text_len:]
print("\nMegatron-LM:", trim_decode_tokens, flush=True) print("\nMegatron-LM:", trim_decode_tokens, flush=True)
input("\nPress Enter to continue >>>")
raw_text = None raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1 context_count += 1
if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model): def generate_samples_unconditional(model):
...@@ -240,29 +295,38 @@ def generate_samples_unconditional(model): ...@@ -240,29 +295,38 @@ def generate_samples_unconditional(model):
num_samples = args.num_samples num_samples = args.num_samples
context_tokens = [[tokenizer.eod] context_tokens = [[tokenizer.eod]
for _ in range(args.batch_size)] for _ in range(args.micro_batch_size)]
ctr = 0 ctr = 0
while True: while True:
start_time = time.time() start_time = time.time()
for token_stream in get_token_stream(model, for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)): copy.deepcopy(context_tokens)):
pass pass
if ctr % args.log_interval == 0: if mpu.is_pipeline_last_stage() and \
print('Avg s/batch:', mpu.get_tensor_model_parallel_rank() == 0:
(time.time() - start_time) / min(args.log_interval, ctr + 1)) if ctr % args.log_interval == 0:
start_time = time.time() print('Avg s/batch:',
length = len(token_stream) (time.time() - start_time) / min(args.log_interval, ctr + 1))
token_batch = token_stream[0].cpu().numpy().tolist() start_time = time.time()
length_batch = token_stream[1].cpu().numpy().tolist() length = len(token_stream)
for tokens, length in zip(token_batch, length_batch): token_batch = token_stream[0].cpu().numpy().tolist()
tokens = tokens[1:length - 1] length_batch = token_stream[1].cpu().numpy().tolist()
text = tokenizer.detokenize(tokens) assert len(length_batch) == args.micro_batch_size
is_finished = length < args.seq_length - 1 for tokens, length in zip(token_batch, length_batch):
datum = {'text': text, 'length': length - 1, 'finished': is_finished} tokens = tokens[1:length - 1]
yield datum text = tokenizer.detokenize(tokens)
ctr += 1 is_finished = length < args.seq_length - 1
if ctr >= num_samples: datum = {'text': text, 'length': length - 1, 'finished': is_finished}
break yield datum
ctr += 1
if ctr >= num_samples:
break
else:
for _ in range(args.micro_batch_size):
yield None
ctr += 1
if ctr >= num_samples:
break
if ctr >= num_samples: if ctr >= num_samples:
break break
...@@ -273,7 +337,9 @@ def generate_and_write_samples_unconditional(model): ...@@ -273,7 +337,9 @@ def generate_and_write_samples_unconditional(model):
assert args.genfile is not None assert args.genfile is not None
with open(args.genfile, 'w') as f: with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model): for datum in generate_samples_unconditional(model):
f.write(json.dumps(datum) + '\n') if mpu.is_pipeline_last_stage() and \
mpu.get_tensor_model_parallel_rank() == 0:
f.write(json.dumps(datum) + '\n')
def pad_batch(batch, pad_id, args): def pad_batch(batch, pad_id, args):
...@@ -299,11 +365,11 @@ def get_token_stream(model, context_tokens): ...@@ -299,11 +365,11 @@ def get_token_stream(model, context_tokens):
context_length_tensor = torch.cuda.LongTensor(context_lengths) context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor, torch.distributed.broadcast(context_length_tensor,
mpu.get_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, torch.distributed.broadcast(context_tokens_tensor,
mpu.get_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item() context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor) tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
...@@ -313,7 +379,10 @@ def get_token_stream(model, context_tokens): ...@@ -313,7 +379,10 @@ def get_token_stream(model, context_tokens):
attention_mask, position_ids) attention_mask, position_ids)
for tokens, lengths in batch_token_iterator: for tokens, lengths in batch_token_iterator:
context_length += 1 context_length += 1
yield tokens[:, :context_length], lengths if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean): def switch(val1, val2, boolean):
...@@ -322,6 +391,66 @@ def switch(val1, val2, boolean): ...@@ -322,6 +391,66 @@ def switch(val1, val2, boolean):
return (1 - boolean) * val1 + boolean * val2 return (1 - boolean) * val1 + boolean * val2
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
# Hidden size changes when not using recompute, need to tell communicate()
# the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
output_tensor, layer_past = output_tensor
if not mpu.is_pipeline_last_stage():
communicate(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
args.seq_length = orig_seq_length
if get_key_value:
return output_tensor, layer_past
return output_tensor
def sample_sequence_batch(model, context_tokens, context_lengths, def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids, attention_mask, position_ids,
maxlen=None, type_ids=None): maxlen=None, type_ids=None):
...@@ -349,14 +478,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...@@ -349,14 +478,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
lengths = torch.ones([batch_size]).long().cuda() * maxlen lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen): while context_length <= (maxlen):
if args.recompute: if args.recompute:
logits = model(tokens, output = forward_step(model, tokens,
position_ids, position_ids,
attention_mask, attention_mask,
tokentype_ids=type_ids, tokentype_ids=type_ids,
forward_method_parallel_output=False) forward_method_parallel_output=False)
logits = logits[:, context_length - 1, :] if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, context_length - 1, :]
else: else:
types2use = None types2use = None
if counter == 0: if counter == 0:
...@@ -372,41 +502,65 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...@@ -372,41 +502,65 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if type_ids is not None: if type_ids is not None:
types2use = type_ids[:, context_length - 1].view( types2use = type_ids[:, context_length - 1].view(
batch_size, -1) batch_size, -1)
logits, layer_past = model(tokens2use, output, layer_past = forward_step(model, tokens2use,
positions2use, positions2use,
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=True, get_key_value=True,
tokentype_ids=types2use, tokentype_ids=types2use,
forward_method_parallel_output=False) forward_method_parallel_output=False)
logits = logits[:, -1].view(batch_size, -1).contiguous() if mpu.is_pipeline_last_stage():
assert output is not None
if args.greedy: logits = output[:, -1].view(batch_size, -1).contiguous()
prev = torch.argmax(logits, dim=-1).view(-1)
if mpu.is_pipeline_last_stage():
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
else: else:
logits = logits.float() if mpu.is_pipeline_first_stage():
logits /= args.temperature src = mpu.get_pipeline_model_parallel_last_rank()
logits = top_k_logits(logits, top_k=args.top_k, group = mpu.get_embedding_group()
top_p=args.top_p) new_tokens = torch.empty_like(tokens[:, context_length])
log_probs = F.softmax(logits, dim=-1) torch.distributed.broadcast(new_tokens, src, group)
prev = torch.multinomial(log_probs, num_samples=1).view(-1) tokens[:, context_length] = new_tokens
yield tokens, None
print_logits = [] else:
for p in prev: yield None, None
print_logits.append([logits[i, p].item()
for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(
tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
done_token = (prev == eos_id).byte() & started.byte() done = torch.cuda.ByteTensor([0])
just_finished = (done_token & ~is_done).bool() src = mpu.get_pipeline_model_parallel_last_rank()
lengths[just_finished.view(-1)] = context_length group = mpu.get_pipeline_model_parallel_group()
is_done = is_done | done_token torch.distributed.broadcast(done, src, group)
done = torch.all(is_done)
yield tokens, lengths context_length += 1
counter += 1
if done: if done:
break break
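sample_sequence_batch filters logits with top_k_logits (defined elsewhere in this file) before the multinomial draw. As an illustration only, here is a common top-k / top-p (nucleus) filter of the kind that call is assumed to perform; the actual implementation in this file may differ in details:

import torch
import torch.nn.functional as F

def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Top-k: mask everything below the k-th largest logit per row.
    if top_k > 0:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, filter_value)
    # Top-p (nucleus): mask the tail whose cumulative probability exceeds top_p.
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cum_probs > top_p
        # Shift right so the token that crosses the threshold is kept,
        # and always keep the most likely token.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits

# e.g. prev = torch.multinomial(
#          F.softmax(top_k_top_p_filter(logits, top_k=40, top_p=0.9), dim=-1),
#          num_samples=1).view(-1)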
...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args): ...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \ multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size args.tensor_model_parallel_size
while (after % multiple) != 0: while (after % multiple) != 0:
after += 1 after += 1
if args.rank == 0: if args.rank == 0:
......
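For _vocab_size_with_padding above, the padded size must be a multiple of make-vocab-size-divisible-by times the tensor-model-parallel size so the vocab-parallel embedding splits evenly across ranks. A worked example with made-up GPT-2-style numbers:

orig_vocab_size = 50257                          # hypothetical tokenizer size
make_vocab_size_divisible_by = 128
tensor_model_parallel_size = 8

multiple = make_vocab_size_divisible_by * tensor_model_parallel_size   # 1024
after = orig_vocab_size
while after % multiple != 0:
    after += 1
print(after)        # 51200, i.e. 6400 vocab rows per tensor-parallel rank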
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
from datetime import datetime from datetime import datetime
import math import math
import sys import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
...@@ -25,13 +29,19 @@ from apex.optimizers import FusedAdam as Adam ...@@ -25,13 +29,19 @@ from apex.optimizers import FusedAdam as Adam
from megatron import get_args from megatron import get_args
from megatron import get_timers from megatron import get_timers
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu from megatron import mpu
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization from megatron.model import get_params_for_weight_decay_optimization
...@@ -41,6 +51,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader ...@@ -41,6 +51,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider, model_provider, def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}): forward_step_func, extra_args_provider=None, args_defaults={}):
"""Main training program. """Main training program.
...@@ -71,6 +88,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -71,6 +88,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
initialize_megatron(extra_args_provider=extra_args_provider, initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults) args_defaults=args_defaults)
# Adjust the startup time so it reflects the largest value.
# This will be closer to what the scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -78,6 +107,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -78,6 +107,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
timers('model and optimizer').start() timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop() timers('model and optimizer').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff. # Data stuff.
timers('train/valid/test data iterators').start() timers('train/valid/test data iterators').start()
...@@ -85,6 +116,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -85,6 +116,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
= build_train_valid_test_data_iterators( = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider) train_valid_test_dataset_provider)
timers('train/valid/test data iterators').stop() timers('train/valid/test data iterators').stop()
print_datetime('after dataloaders are built')
# Print setup timing. # Print setup timing.
print_rank_0('done with setups ...') print_rank_0('done with setups ...')
...@@ -96,6 +128,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -96,6 +128,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration = train(forward_step_func, iteration = train(forward_step_func,
model, optimizer, lr_scheduler, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator) train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
...@@ -113,6 +146,35 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -113,6 +146,35 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
test_data_iterator, model, test_data_iterator, model,
0, True) 0, True)
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
if args.train_iters:
return
# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
else:
# Sample based training with rampup batch size.
iterations = 0
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.train_iters = iterations
print_rank_0('setting training iterations to {}'.format(args.train_iters))
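update_train_iters above converts a sample budget into an iteration count by replaying the batch-size rampup. A self-contained numerical sketch under the assumed --rampup-batch-size semantics of <start> <increment> <ramp samples>, where the global batch grows from <start> to global-batch-size in steps of <increment> spread evenly over <ramp samples> consumed samples (all numbers made up):

train_samples = 10_000
global_batch_size = 64
start, increment, ramp_samples = 16, 16, 4_800    # stand-in for rampup_batch_size

num_increments = (global_batch_size - start) // increment       # 3
samples_per_increment = ramp_samples // num_increments          # 1,600

def current_global_batch_size(consumed):
    steps = min(consumed // samples_per_increment, num_increments)
    return start + steps * increment

iterations, consumed = 0, 0
while consumed <= ramp_samples:                   # rampup phase
    consumed += current_global_batch_size(consumed)
    iterations += 1
iterations += (train_samples - consumed) // global_batch_size   # constant phase
print(iterations)   # 184 rampup + 80 constant = 264 for these numbers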
def get_model(model_provider_func): def get_model(model_provider_func):
"""Build the model.""" """Build the model."""
...@@ -123,8 +185,10 @@ def get_model(model_provider_func): ...@@ -123,8 +185,10 @@ def get_model(model_provider_func):
# Print number of parameters. # Print number of parameters.
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format( print(' > number of parameters on (tensor, pipeline) '
mpu.get_model_parallel_rank(), 'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True) sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation. # GPU allocation.
...@@ -134,7 +198,6 @@ def get_model(model_provider_func): ...@@ -134,7 +198,6 @@ def get_model(model_provider_func):
if args.fp16: if args.fp16:
model = FP16_Module(model) model = FP16_Module(model)
# Wrap model for distributed training."""
if args.DDP_impl == 'torch': if args.DDP_impl == 'torch':
i = torch.cuda.current_device() i = torch.cuda.current_device()
model = torchDDP(model, device_ids=[i], output_device=i, model = torchDDP(model, device_ids=[i], output_device=i,
...@@ -160,8 +223,8 @@ def get_optimizer(model): ...@@ -160,8 +223,8 @@ def get_optimizer(model):
# Add model parallel attribute if it is not set. # Add model parallel attribute if it is not set.
for param_group in param_groups: for param_group in param_groups:
for param in param_group['params']: for param in param_group['params']:
if not hasattr(param, 'model_parallel'): if not hasattr(param, 'tensor_model_parallel'):
param.model_parallel = False param.tensor_model_parallel = False
# Use Adam. # Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay, optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
...@@ -184,22 +247,39 @@ def get_learning_rate_scheduler(optimizer): ...@@ -184,22 +247,39 @@ def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler.""" """Build the learning rate scheduler."""
args = get_args() args = get_args()
# Add linear learning rate scheduler. # Iteration-based training.
if args.lr_decay_iters is not None: if args.train_iters:
num_iters = args.lr_decay_iters if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
decay_steps = args.lr_decay_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
decay_steps = args.lr_decay_samples
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
else:
warmup_steps = args.lr_warmup_samples
else: else:
num_iters = args.train_iters raise Exception(
num_iters = max(1, num_iters) 'either train-iters or train-samples should be provided.')
init_step = 0
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR( lr_scheduler = AnnealingLR(
optimizer, optimizer,
max_lr=args.lr, max_lr=args.lr,
min_lr=args.min_lr, min_lr=args.min_lr,
warmup_steps=warmup_iter, warmup_steps=warmup_steps,
decay_steps=num_iters, decay_steps=decay_steps,
decay_style=args.lr_decay_style, decay_style=args.lr_decay_style,
num_steps=init_step,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler) override_lr_scheduler=args.override_lr_scheduler)
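In the rewritten get_learning_rate_scheduler, warmup_steps and decay_steps are counted in samples, which is why the iteration-based branch multiplies by global_batch_size. A quick worked example with hypothetical settings:

train_iters = 500_000
global_batch_size = 1024
lr_warmup_fraction = 0.01

lr_decay_iters = train_iters                        # defaulted when not set
decay_steps = lr_decay_iters * global_batch_size    # 512,000,000 samples
warmup_steps = lr_warmup_fraction * decay_steps     # 5,120,000 samples
print(decay_steps, warmup_steps)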
...@@ -215,10 +295,22 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -215,10 +295,22 @@ def setup_model_and_optimizer(model_provider_func):
lr_scheduler = get_learning_rate_scheduler(optimizer) lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.load is not None: if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler) args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('load checkpoint').stop()
timers.log(['load checkpoint'])
else: else:
args.iteration = 0 args.iteration = 0
# We only support local DDP with multiple micro-batches.
if get_num_microbatches() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers # get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model unwrapped_model = model
while hasattr(unwrapped_model, 'module'): while hasattr(unwrapped_model, 'module'):
...@@ -232,26 +324,304 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -232,26 +324,304 @@ def setup_model_and_optimizer(model_provider_func):
return model, optimizer, lr_scheduler return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss): def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=args.params_dtype)
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=args.params_dtype)
# Send tensors in both the forward and backward directions as appropriate.
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
return tensor_recv_prev, tensor_recv_next
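communicate() above calls torch.distributed.ring_exchange, which is not available in stock PyTorch releases and presumably assumes a patched build. For readers without that patch, the same neighbour exchange can be expressed with the standard batch_isend_irecv API; this is an equivalent-in-spirit sketch, not the code in this commit (prev_rank and next_rank stand in for the pipeline neighbour ranks that mpu would provide):

import torch
import torch.distributed as dist

def p2p_communicate(tensor_send_next, tensor_send_prev,
                    recv_prev_shape=None, recv_next_shape=None,
                    dtype=torch.float32, group=None,
                    prev_rank=None, next_rank=None):
    tensor_recv_prev, tensor_recv_next = None, None
    ops = []
    if recv_prev_shape is not None:
        tensor_recv_prev = torch.empty(recv_prev_shape, requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank, group))
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank, group))
    if recv_next_shape is not None:
        tensor_recv_next = torch.empty(recv_next_shape, requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank, group))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank, group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next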
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
"""Backward step.""" """Backward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass. # Backward pass.
timers('backward-backward').start()
optimizer.zero_grad(set_grads_to_None=True)
if args.fp16: if args.fp16:
optimizer.backward(loss, update_master_grads=False) optimizer.backward(output_tensor, update_master_grads=False,
output_tensor_grad=output_tensor_grad)
else:
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
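backward_step above recovers the gradient to send upstream by keeping .grad on the received activation (retain_grad) and running autograd with the gradient that arrived from the downstream stage. A single-process illustration of the mechanism, with a toy linear layer standing in for one pipeline stage:

import torch

stage = torch.nn.Linear(4, 4)                          # this rank's slice of the model
input_tensor = torch.randn(2, 4, requires_grad=True)   # "received" activation
input_tensor.retain_grad()                             # make sure .grad is kept

output_tensor = stage(input_tensor)
output_tensor_grad = torch.ones_like(output_tensor)    # "received" from next stage

torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

# This is what would be sent to the previous stage.
input_tensor_grad = input_tensor.grad
print(input_tensor_grad.shape)                         # torch.Size([2, 4])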
def forward_step_with_communication(forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
if not mpu.is_pipeline_first_stage():
timers('forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
else: else:
loss.backward() input_tensor = None
timers('backward-backward').stop()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
else:
timers('forward-send').start()
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
timers('forward-send').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
else:
timers('backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('backward-recv').stop()
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send').start()
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=False,
recv_backward=False)
timers('backward-send').stop()
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_microbatch,
input_tensors, output_tensors,
losses_reduced, timers):
args = get_args()
# Forward model for one step.
timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop()
if mpu.is_pipeline_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
output_tensor_grad = None
losses_reduced.append(loss_reduced)
else:
timers('forward-send-backward-recv').start()
_, output_tensor_grad = communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
timers('forward-send-backward-recv').stop()
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# Backward pass for one step.
timers('backward-compute').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
if not mpu.is_pipeline_first_stage():
timers('backward-send-forward-recv').start()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=(not last_microbatch),
recv_backward=False)
timers('backward-send-forward-recv').stop()
else:
input_tensor = None
return input_tensor
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run forward and backward passes without inter-stage communication."""
args = get_args()
losses_reduced = []
for i in range(get_num_microbatches()):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
timers('backward-compute').start()
output_tensor_grad = None
backward_step(optimizer, model, input_tensor=None,
output_tensor=output_tensor, output_tensor_grad=None)
timers('backward-compute').stop()
return losses_reduced
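Dividing each microbatch loss by get_num_microbatches() is what makes gradient accumulation average, rather than sum, the per-microbatch gradients. A tiny standalone sketch of that effect, with made-up losses:

import torch

w = torch.zeros(1, requires_grad=True)
microbatch_losses = [2.0 * w, 4.0 * w]                  # per-microbatch losses, made up
for loss in microbatch_losses:
    (loss / len(microbatch_losses)).sum().backward()    # gradients accumulate in w.grad
# w.grad is now tensor([3.]), the gradient of the mean of the two losses.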
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
args = get_args()
# Compute number of warmup microbatches.
num_microbatches = get_num_microbatches()
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
mpu.get_pipeline_model_parallel_rank() - 1)
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
timers('forward-recv').start()
input_tensor, _ = communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
timers('forward-recv').stop()
# Run 1F1B.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
input_tensor, last_iteration,
input_tensors, output_tensors,
losses_reduced, timers)
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
return losses_reduced
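For intuition on the warmup/cooldown split computed above: each stage runs pipeline_world_size - rank - 1 warmup forward passes (capped at the number of microbatches) before entering steady-state 1F1B, and the same number of cooldown backward passes afterwards. A small sketch with assumed sizes:

# Assumed sizes for illustration only.
pipeline_world_size = 4
num_microbatches = 8

for rank in range(pipeline_world_size):
    num_warmup = min(pipeline_world_size - rank - 1, num_microbatches)
    remaining = num_microbatches - num_warmup
    print('stage {}: {} warmup forwards, {} 1F1B steps, {} cooldown backwards'.format(
        rank, num_warmup, remaining, num_warmup))
# stage 0: 3 warmup forwards, 5 1F1B steps, 3 cooldown backwards
# stage 3: 0 warmup forwards, 8 1F1B steps, 0 cooldown backwards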
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
else:
losses_reduced = forward_backward_no_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
# All-reduce if needed. # All-reduce if needed.
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
timers('backward-allreduce').start() timers('backward-params-all-reduce').start()
model.allreduce_params(reduce_after=False, model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce) fp32_allreduce=args.fp32_allreduce)
timers('backward-allreduce').stop() timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
# Update master gradients. # Update master gradients.
timers('backward-master-grad').start() timers('backward-master-grad').start()
...@@ -261,30 +631,20 @@ def backward_step(optimizer, model, loss): ...@@ -261,30 +631,20 @@ def backward_step(optimizer, model, loss):
# Clipping gradients helps prevent the exploding gradient. # Clipping gradients helps prevent the exploding gradient.
timers('backward-clip-grad').start() timers('backward-clip-grad').start()
if args.clip_grad > 0: if args.clip_grad > 0.:
if not args.fp16: if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad) named_parameters = model.named_parameters()
parameters = []
parameter_names = []
for parameter_name, parameter in model.named_parameters():
parameters.append(parameter)
parameter_names.append(parameter_name)
mpu.clip_grad_norm(parameters, args.clip_grad,
parameter_names=parameter_names)
else: else:
optimizer.clip_master_grads(args.clip_grad) optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop() timers('backward-clip-grad').stop()
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Forward model for one step.
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss)
timers('backward').stop()
# Update parameters. # Update parameters.
timers('optimizer').start() timers('optimizer').start()
optimizer.step() optimizer.step()
...@@ -293,11 +653,21 @@ def train_step(forward_step_func, data_iterator, ...@@ -293,11 +653,21 @@ def train_step(forward_step_func, data_iterator,
# Update learning rate. # Update learning rate.
skipped_iter = 0 skipped_iter = 0
if not (args.fp16 and optimizer.overflow): if not (args.fp16 and optimizer.overflow):
lr_scheduler.step() increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
else: else:
skipped_iter = 1 skipped_iter = 1
return loss_reduced, skipped_iter if mpu.is_pipeline_last_stage():
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter
return {}, skipped_iter
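Two details of train_step that are easy to gloss over: the learning-rate scheduler is advanced by the number of samples in the global batch, and the per-microbatch loss dictionaries are averaged before being returned. A sketch with assumed values:

# Scheduler increment = samples consumed per training step (assumed sizes).
num_microbatches, micro_batch_size, data_parallel_size = 4, 8, 2
increment = num_microbatches * micro_batch_size * data_parallel_size    # 64 samples

# Averaging the per-microbatch loss dicts (values made up).
losses_reduced = [{'lm loss': 3.5}, {'lm loss': 2.5}, {'lm loss': 3.0}]
loss_reduced = {}
for key in losses_reduced[0]:
    vals = [x[key] for x in losses_reduced]
    loss_reduced[key] = sum(vals) / len(vals)    # {'lm loss': 3.0}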
def training_log(loss_dict, total_loss_dict, learning_rate, iteration, def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
...@@ -307,12 +677,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -307,12 +677,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
timers = get_timers() timers = get_timers()
writer = get_tensorboard_writer() writer = get_tensorboard_writer()
# Update losses. # Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations' skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
# Advanced iterations.
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
# Skipped iterations.
total_loss_dict[skipped_iters_key] = total_loss_dict.get( total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter skipped_iters_key, 0) + skipped_iter
got_nan_key = 'got nan' # Update losses and set nan iterations
got_nan = False got_nan = False
for key in loss_dict: for key in loss_dict:
if not skipped_iter: if not skipped_iter:
...@@ -324,9 +703,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -324,9 +703,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
value == -float('inf') or \ value == -float('inf') or \
value != value value != value
got_nan = got_nan or is_nan got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
total_loss_dict[got_nan_key] = total_loss_dict.get( nan_iters_key, 0) + int(got_nan)
got_nan_key, 0) + int(got_nan)
# Logging. # Logging.
timers_to_log = [] timers_to_log = []
...@@ -334,43 +712,66 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -334,43 +712,66 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
def add_to_logging(name): def add_to_logging(name):
if name in timers.timers: if name in timers.timers:
timers_to_log.append(name) timers_to_log.append(name)
add_to_logging('forward') add_to_logging('forward-compute')
add_to_logging('backward') add_to_logging('forward-recv')
add_to_logging('backward-backward') add_to_logging('forward-send')
add_to_logging('backward-allreduce') add_to_logging('forward-send-backward-recv')
add_to_logging('backward-compute')
add_to_logging('backward-recv')
add_to_logging('backward-send')
add_to_logging('backward-send-forward-recv')
add_to_logging('backward-master-grad') add_to_logging('backward-master-grad')
add_to_logging('backward-params-all-reduce')
add_to_logging('backward-embedding-all-reduce')
add_to_logging('backward-clip-grad') add_to_logging('backward-clip-grad')
add_to_logging('optimizer') add_to_logging('optimizer')
add_to_logging('batch generator') add_to_logging('batch-generator')
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# Tensorboard values. # Tensorboard values.
if writer and torch.distributed.get_rank() == 0: if writer and is_last_rank():
writer.add_scalar('learning_rate', learning_rate, iteration) writer.add_scalar('learning-rate', learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
for key in loss_dict: for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration) writer.add_scalar(key , loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if args.fp16: if args.fp16:
writer.add_scalar('loss_scale', loss_scale, iteration) writer.add_scalar('loss-scale', loss_scale, iteration)
normalizer = iteration % args.log_interval writer.add_scalar('loss-scale vs samples', loss_scale,
if normalizer == 0: args.consumed_train_samples)
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration, timers.write(timers_to_log, writer, iteration,
normalizer=normalizer) normalizer=total_iterations)
if iteration % args.log_interval == 0: if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed() elapsed_time = timers('interval time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer and torch.distributed.get_rank() == 0: if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('iteration_time', writer.add_scalar('iteration-time',
elapsed_time / args.log_interval, iteration) elapsed_time_per_iteration, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration, log_string = ' iteration {:8d}/{:8d} |'.format(
args.train_iters) iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval) elapsed_time_per_iteration * 1000.0)
log_string += ' learning rate: {:.3E} |'.format(learning_rate) log_string += ' learning rate: {:.3E} |'.format(learning_rate)
num_iterations = max( log_string += ' global batch size: {:5d} |'.format(batch_size)
1, args.log_interval - total_loss_dict[skipped_iters_key])
for key in total_loss_dict: for key in total_loss_dict:
if key not in [skipped_iters_key, got_nan_key]: if key not in [advanced_iters_key, skipped_iters_key,
avg = total_loss_dict[key].item() / float(num_iterations) nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0: if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg) log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
...@@ -379,24 +780,41 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -379,24 +780,41 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' number of skipped iterations: {:3d} |'.format( log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key]) total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format( log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[got_nan_key]) total_loss_dict[nan_iters_key])
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0
total_loss_dict[got_nan_key] = 0 total_loss_dict[nan_iters_key] = 0
print_rank_0(log_string) print_rank_last(log_string)
if report_memory_flag: if report_memory_flag and learning_rate > 0.:
report_memory('after {} iterations'.format(iteration)) # Report memory after optimizer state has been initialized.
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval) timers.log(timers_to_log, normalizer=args.log_interval)
return report_memory_flag return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('save checkpoint').stop()
timers.log(['save checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler, def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator): train_data_iterator, valid_data_iterator):
"""Train the model function.""" """Train the model function."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Write args to tensorboard
write_args_to_tensorboard()
# Turn on training mode which enables dropout. # Turn on training mode which enables dropout.
model.train() model.train()
...@@ -407,8 +825,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -407,8 +825,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration = args.iteration iteration = args.iteration
timers('interval time').start() timers('interval time').start()
print_datetime('before the start of training step')
report_memory_flag = True report_memory_flag = True
while iteration < args.train_iters: while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
loss_dict, skipped_iter = train_step(forward_step_func, loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator, train_data_iterator,
model, model,
...@@ -416,7 +836,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -416,7 +836,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
lr_scheduler) lr_scheduler)
iteration += 1 iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.batch_size args.micro_batch_size * \
get_num_microbatches()
# Logging. # Logging.
loss_scale = None loss_scale = None
...@@ -434,9 +855,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -434,9 +855,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
lr_scheduler) lr_scheduler)
# Checkpointing # Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \ if args.save and args.save_interval and \
iteration % args.save_interval == 0: iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
saved_checkpoint = True
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \ if args.eval_interval and iteration % args.eval_interval == 0 and \
...@@ -446,14 +871,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -446,14 +871,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
valid_data_iterator, model, valid_data_iterator, model,
iteration, False) iteration, False)
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') print_datetime('exiting program at iteration {}'.format(iteration))
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
'iteration {}'.format(rank, time_str, iteration))
sys.exit() sys.exit()
return iteration return iteration
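The exit-duration check above uses an all-reduce with MAX so that every rank reaches the same exit decision even if only one rank has crossed the time limit. Stripped of the Megatron specifics, the pattern is roughly the following sketch; it assumes a CUDA device and an already-initialized process group whose backend accepts CUDA tensors.

import torch
import torch.distributed as dist

def everyone_should_exit(local_done):
    # Turn a per-rank boolean into a global decision: if any rank says
    # "done", the MAX reduction makes the flag 1 on every rank.
    flag = torch.cuda.IntTensor([int(local_done)])
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())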
...@@ -473,23 +915,44 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -473,23 +915,44 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
if verbose and iteration % args.log_interval == 0: if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration, print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters)) args.eval_iters))
# Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model) for _ in range(get_num_microbatches()):
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
_, loss_dict = output_tensor
# Reduce across processes.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
loss_dict[key]
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.batch_size * args.micro_batch_size \
# Reduce across processes. * get_num_microbatches()
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key]
# Move model back to the train mode. # Move model back to the train mode.
model.train() model.train()
for key in total_loss_dict: for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict return total_loss_dict
def evaluate_and_print_results(prefix, forward_step_func, def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model, data_iterator, model,
iteration, verbose=False): iteration, verbose=False):
...@@ -509,9 +972,9 @@ def evaluate_and_print_results(prefix, forward_step_func, ...@@ -509,9 +972,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} ppl'.format(key), ppl, iteration) writer.add_scalar('{} ppl'.format(key), ppl, iteration)
length = len(string) + 1 length = len(string) + 1
print_rank_0('-' * length) print_rank_last('-' * length)
print_rank_0(string) print_rank_last(string)
print_rank_0('-' * length) print_rank_last('-' * length)
def build_train_valid_test_data_iterators( def build_train_valid_test_data_iterators(
...@@ -523,26 +986,31 @@ def build_train_valid_test_data_iterators( ...@@ -523,26 +986,31 @@ def build_train_valid_test_data_iterators(
print_rank_0('> building train, validation, and test datasets ...') print_rank_0('> building train, validation, and test datasets ...')
# Rank and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size()
global_batch_size = args.batch_size * data_parallel_size
# Backward compatibility, assume fixed batch size. # Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0: if args.iteration > 0 and args.consumed_train_samples == 0:
args.consumed_train_samples = args.iteration * global_batch_size assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0: if args.iteration > 0 and args.consumed_valid_samples == 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \ args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * global_batch_size args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
# Number of train/valid/test samples. # Number of train/valid/test samples.
train_iters = args.train_iters if args.train_samples:
eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters test_iters = args.eval_iters
train_val_test_num_samples = [train_iters * global_batch_size, train_val_test_num_samples = [train_samples,
eval_iters * global_batch_size, eval_iters * args.global_batch_size,
test_iters * global_batch_size] test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):') print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
...@@ -571,12 +1039,12 @@ def build_train_valid_test_data_iterators( ...@@ -571,12 +1039,12 @@ def build_train_valid_test_data_iterators(
# Broadcast num tokens. # Broadcast num tokens.
torch.distributed.broadcast(flags, torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
args.do_train = flags[0].item() args.do_train = flags[0].item()
args.do_valid = flags[1].item() args.do_valid = flags[1].item()
args.do_test = flags[2].item() args.do_test = flags[2].item()
# Build iterators. # Build iterators.
if train_dataloader is not None: if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) train_data_iterator = iter(train_dataloader)
......
...@@ -27,14 +27,16 @@ from megatron.checkpointing import save_checkpoint ...@@ -27,14 +27,16 @@ from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
def reduce_losses(losses): def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs.""" """Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat( averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses]) [loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses) torch.distributed.all_reduce(averaged_losses,
reduced_losses = reduced_losses / torch.distributed.get_world_size() group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return reduced_losses return averaged_losses
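Viewed outside of Megatron's mpu helpers, the averaging above is simply an all-reduce (sum) over the data-parallel group followed by a division by that group's size. A minimal sketch using the default process group; this is a simplification, since the code above passes mpu.get_data_parallel_group().

import torch
import torch.distributed as dist

def average_scalar_losses(losses, group=None):
    # Stack the scalar losses, sum them across ranks, then divide by the
    # number of ranks in the group to obtain the mean on every rank.
    stacked = torch.cat([loss.clone().detach().view(1) for loss in losses])
    dist.all_reduce(stacked, group=group)
    return stacked / dist.get_world_size(group=group)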
def report_memory(name): def report_memory(name):
...@@ -48,14 +50,15 @@ def report_memory(name): ...@@ -48,14 +50,15 @@ def report_memory(name):
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes) string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format( string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes) torch.cuda.max_memory_reserved() / mega_bytes)
print_rank_0(string) if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def print_params_min_max_norm(optimizer, iteration): def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters.""" """Print min, max, and norm of all parameters."""
index = 0 index = 0
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n' string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer): if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer optimizer_ = optimizer.optimizer
...@@ -66,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -66,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max() max_ = param.data.max()
norm = torch.linalg.norm(param.data) norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel)) iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True) print(string, flush=True)
...@@ -96,11 +99,11 @@ def get_ltor_masks_and_position_ids(data, ...@@ -96,11 +99,11 @@ def get_ltor_masks_and_position_ids(data,
"""Build masks and position id for left to right model.""" """Build masks and position id for left to right model."""
# Extract batch size and sequence length. # Extract batch size and sequence length.
batch_size, seq_length = data.size() micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular). # Attention mask (lower triangular).
if reset_attention_mask: if reset_attention_mask:
att_mask_batch = batch_size att_mask_batch = micro_batch_size
else: else:
att_mask_batch = 1 att_mask_batch = 1
attention_mask = torch.tril(torch.ones( attention_mask = torch.tril(torch.ones(
...@@ -122,7 +125,7 @@ def get_ltor_masks_and_position_ids(data, ...@@ -122,7 +125,7 @@ def get_ltor_masks_and_position_ids(data,
if reset_position_ids or reset_attention_mask: if reset_position_ids or reset_attention_mask:
# Loop through the batches: # Loop through the batches:
for b in range(batch_size): for b in range(micro_batch_size):
# Find indices where EOD token is. # Find indices where EOD token is.
eod_index = position_ids[b, data[b] == eod_token] eod_index = position_ids[b, data[b] == eod_token]
......
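For reference, the core of the causal mask built by get_ltor_masks_and_position_ids is a lower-triangular matrix over the sequence. A toy sketch with an assumed sequence length; the final comparison to a boolean mask reflects the usual consumption pattern and is illustrative rather than a quote of the full function.

import torch

seq_length = 4
attention_mask = torch.tril(torch.ones(1, seq_length, seq_length))
# tensor([[[1., 0., 0., 0.],
#          [1., 1., 0., 0.],
#          [1., 1., 1., 0.],
#          [1., 1., 1., 1.]]])
attention_mask = (attention_mask < 0.5)                  # True marks positions a token may NOT attend to
position_ids = torch.arange(seq_length).unsqueeze(0)     # [[0, 1, 2, 3]]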
...@@ -23,9 +23,9 @@ from megatron import print_rank_0 ...@@ -23,9 +23,9 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider():
...@@ -33,10 +33,25 @@ def model_provider(): ...@@ -33,10 +33,25 @@ def model_provider():
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
model = BertModel( args = get_args()
num_tokentypes=2, if mpu.get_pipeline_model_parallel_world_size() > 1:
add_binary_head=True, # Determine model based on position of stage in pipeline.
parallel_output=True) if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
return model return model
...@@ -66,34 +81,51 @@ def get_batch(data_iterator): ...@@ -66,34 +81,51 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
# Forward model. lm_labels if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = model(tokens, padding_mask, lm_loss_, sop_logits = output_tensor
tokentype_ids=types,
lm_labels=lm_labels)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1), sentence_order.view(-1),
ignore_index=-1) ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss = torch.sum( lm_loss_ = lm_loss_.float()
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss]) averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]} return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -23,16 +23,28 @@ from megatron import get_timers ...@@ -23,16 +23,28 @@ from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.data.gpt2_dataset import build_train_valid_test_datasets from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
print_rank_0('building GPT2 model ...') print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True) args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPT2ModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model return model
...@@ -69,25 +81,42 @@ def get_batch(data_iterator): ...@@ -69,25 +81,42 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch( tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator) data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels) # Forward pass through the model.
loss_mask = loss_mask.view(-1) if mpu.is_pipeline_first_stage():
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging. # Reduce loss for logging.
reduced_loss = reduce_losses([loss]) averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]} return loss, {'lm loss': averaged_loss[0]}
return output_tensor
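The loss computed on the last stage is a masked token-average: per-token losses are summed only where loss_mask is 1 and divided by the number of unmasked tokens. A toy numeric sketch with made-up values:

import torch

losses = torch.tensor([[2.0, 1.0, 3.0, 4.0]])               # per-token LM losses
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]]).view(-1)   # 0 = ignore (e.g. padding)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()   # (2.0 + 1.0) / 2 = 1.5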
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -25,12 +25,13 @@ from megatron import get_timers ...@@ -25,12 +25,13 @@ from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider(): def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False) return general_ict_model_provider(False, False)
...@@ -72,22 +73,22 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function): ...@@ -72,22 +73,22 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
return output return output
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
query_tokens, query_pad_mask, \ query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator) block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward model. # Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask) query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
local_batch_size = query_logits.shape[0] micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1 global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits) all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits) all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
...@@ -102,11 +103,12 @@ def forward_step(data_iterator, model): ...@@ -102,11 +103,12 @@ def forward_step(data_iterator, model):
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies] topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda()) retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *topk_accs]) retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies # create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])} topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict) stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
return retrieval_loss, stats_dict return retrieval_loss, stats_dict
......
...@@ -21,8 +21,9 @@ import time ...@@ -21,8 +21,9 @@ import time
import torch import torch
from megatron import get_args from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_last, is_last_rank
from megatron import mpu from megatron import mpu
from megatron.training import communicate
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch from tasks.finetune_utils import process_batch
...@@ -37,12 +38,12 @@ def accuracy_func_provider(single_dataset_provider): ...@@ -37,12 +38,12 @@ def accuracy_func_provider(single_dataset_provider):
for datapath in datapaths: for datapath in datapaths:
dataset = single_dataset_provider(datapath) dataset = single_dataset_provider(datapath)
dataloader = build_data_loader( dataloader = build_data_loader(
dataset, args.batch_size, num_workers=args.num_workers, dataset, args.micro_batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1)) drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader)) dataloaders.append((dataset.dataset_name, dataloader))
def metrics_func(model, epoch, output_predictions=False): def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics ...') print_rank_last('calculating metrics ...')
correct = 0 correct = 0
total = 0 total = 0
if output_predictions: if output_predictions:
...@@ -60,25 +61,26 @@ def accuracy_func_provider(single_dataset_provider): ...@@ -60,25 +61,26 @@ def accuracy_func_provider(single_dataset_provider):
names += '_' + name names += '_' + name
correct += correct_ans correct += correct_ans
total += total_count total += total_count
percent = float(correct) * 100.0 / float(total) if is_last_rank():
print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = ' percent = float(correct) * 100.0 / float(total)
'{:.4f} %'.format(epoch, correct, total, percent)) print(' >> |epoch: {}| overall: correct / total = {} / {} = '
'{:.4f} %'.format(epoch, correct, total, percent))
if output_predictions and torch.distributed.get_rank() == 0: if output_predictions and is_last_rank():
assert args.load is not None assert args.load is not None
filename = os.path.join(args.load, names + '.pt') filename = os.path.join(args.load, names + '.pt')
torch.save(named_predictions, filename) torch.save(named_predictions, filename)
return metrics_func return metrics_func
def calculate_correct_answers(name, model, dataloader, def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions): epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the """Calculate correct over total answers and return prediction if the
`output_predictions` is true.""" `output_predictions` is true."""
args = get_args()
start_time = time.time() start_time = time.time()
model.eval() model.eval()
saved_batch_size = args.micro_batch_size
with torch.no_grad(): with torch.no_grad():
# For all the batches in the dataset. # For all the batches in the dataset.
total = 0 total = 0
...@@ -92,36 +94,79 @@ def calculate_correct_answers(name, model, dataloader, ...@@ -92,36 +94,79 @@ def calculate_correct_answers(name, model, dataloader,
for _, batch in enumerate(dataloader): for _, batch in enumerate(dataloader):
# Run the model forward. # Run the model forward.
tokens, types, labels_, attention_mask = process_batch(batch) tokens, types, labels_, attention_mask = process_batch(batch)
logits = model(tokens, attention_mask, types)
# Add output predictions. # For evaluation only mode we use drop_last = False to get all the
if output_predictions: # samples, which means we might not have a full batch, so we
softmaxes.extend(torch.nn.Softmax(dim=-1)( # adjust batch_size here to actual batch size of data
logits.float()).data.cpu().numpy().tolist()) actual_batch_size = len(labels_)
labels.extend(labels_.data.cpu().numpy().tolist()) # ... applying sample_multiplier if necessary
ids.extend(batch['uid'].cpu().numpy().tolist()) ds = dataloader.dataset
# Compute the correct answers. if hasattr(ds, 'sample_multiplier'):
predicted = torch.argmax(logits, dim=-1) actual_batch_size *= ds.sample_multiplier
corrects = (predicted == labels_) args.micro_batch_size = actual_batch_size
# Add to the counters.
total += labels_.size(0) if not mpu.is_pipeline_first_stage():
correct += corrects.sum().item() input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
# Add output predictions.
if output_predictions:
softmaxes.extend(torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist())
labels.extend(labels_.data.cpu().numpy().tolist())
ids.extend(batch['uid'].cpu().numpy().tolist())
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels_)
# Add to the counters.
total += labels_.size(0)
correct += corrects.sum().item()
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
model.train() model.train()
args.micro_batch_size = saved_batch_size
# Reduce. # Reduce.
unreduced = torch.cuda.LongTensor([correct, total]) if mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(unreduced, unreduced = torch.cuda.LongTensor([correct, total])
group=mpu.get_data_parallel_group()) torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item() # Print on screen.
total_count = unreduced[1].item()
percent = float(correct_ans) * 100.0 / float(total_count) correct_ans = unreduced[0].item()
elapsed_time = time.time() - start_time total_count = unreduced[1].item()
print_rank_0(' > |epoch: {}| metrics for {}: correct / total ' percent = float(correct_ans) * 100.0 / float(total_count)
'= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format( elapsed_time = time.time() - start_time
epoch, name, correct_ans, total_count, print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
percent, elapsed_time)) '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
epoch, name, correct_ans, total_count,
percent, elapsed_time))
if output_predictions:
return correct_ans, total_count, (softmaxes, labels, ids)
return correct_ans, total_count
if output_predictions: if output_predictions:
return correct_ans, total_count, (softmaxes, labels, ids) return 0, 0, ()
return correct_ans, total_count return 0, 0
...@@ -28,7 +28,7 @@ from megatron.training import setup_model_and_optimizer ...@@ -28,7 +28,7 @@ from megatron.training import setup_model_and_optimizer
from megatron.training import train_step from megatron.training import train_step
from megatron.training import training_log from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch): def process_batch(batch):
...@@ -45,33 +45,42 @@ def process_batch(batch): ...@@ -45,33 +45,42 @@ def process_batch(batch):
return tokens, types, labels, attention_mask return tokens, types, labels, attention_mask
def _cross_entropy_forward_step(batch, model): def _cross_entropy_forward_step(batch, model, input_tensor):
"""Simple forward step with cross-entropy loss.""" """Simple forward step with cross-entropy loss."""
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
try: try:
batch_ = next(batch) batch_ = next(batch)
except BaseException: except BaseException:
batch_ = batch batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_) tokens, types, labels, attention_mask = process_batch(batch_)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward model. # Forward model.
logits = model(tokens, attention_mask, types) if mpu.is_pipeline_first_stage():
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
# Cross-entropy loss. # Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss() loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels) loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging. # Reduce loss for logging.
reduced_loss = reduce_losses([loss]) averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]} return loss, {'lm loss': averaged_loss[0]}
return output_tensor
def build_data_loader(dataset, batch_size, num_workers, drop_last): def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
"""Data loader. Note that batch-size is the local (per GPU) batch-size.""" """Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler. # Sampler.
...@@ -82,7 +91,7 @@ def build_data_loader(dataset, batch_size, num_workers, drop_last): ...@@ -82,7 +91,7 @@ def build_data_loader(dataset, batch_size, num_workers, drop_last):
# Data loader. Note that batch size is the per GPU batch size. # Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset, data_loader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, batch_size=micro_batch_size,
sampler=sampler, sampler=sampler,
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=num_workers,
...@@ -109,17 +118,26 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset): ...@@ -109,17 +118,26 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
print_rank_0('building train and validation dataloaders ...') print_rank_0('building train and validation dataloaders ...')
# Training dataset. # Training dataset.
train_dataloader = build_data_loader(train_dataset, args.batch_size, train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last) args.num_workers, not args.keep_last)
# Set the training iterations. # Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader) args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up # Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop. # shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.batch_size, valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last) args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
if hasattr(train_dataset, 'sample_multiplier'):
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader
...@@ -211,6 +229,8 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -211,6 +229,8 @@ def finetune(train_valid_datasets_provider, model_provider,
train_dataset, valid_dataset = train_valid_datasets_provider() train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders( train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset) train_dataset, valid_dataset)
else:
args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop() timers('train/valid/test dataset/dataloder').stop()
# Build callback function. # Build callback function.
...@@ -255,5 +275,4 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -255,5 +275,4 @@ def finetune(train_valid_datasets_provider, model_provider,
if end_of_epoch_callback is not None: if end_of_epoch_callback is not None:
print_rank_0('evaluation only mode, setting epoch to -1') print_rank_0('evaluation only mode, setting epoch to -1')
end_of_epoch_callback(model, epoch=-1, output_predictions=True) end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0('done :-)') print_rank_0('done :-)')