Commit 9b174da8 authored by Deepak Narayanan

Merge branch 'pipeline_parallel_main' into 'main'

Pipeline parallelism and inter-layer model parallelism implementation

See merge request ADLR/megatron-lm!159
parents 3aacd955 6e83649f
@@ -15,9 +15,9 @@
 import torch
-from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
-from .initialize import get_model_parallel_src_rank
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_src_rank
 _MAX_DATA_DIM = 4
@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
     sizes = [0 for _ in range(max_dim) for _ in keys]
     # Pack the sizes on rank zero.
-    if get_model_parallel_rank() == 0:
+    if get_tensor_model_parallel_rank() == 0:
         offset = 0
         for key in keys:
             assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
     # Move to GPU and broadcast.
     sizes_cuda = torch.cuda.LongTensor(sizes)
-    torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
-                                group=get_model_parallel_group())
+    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
+                                group=get_tensor_model_parallel_group())
     # Move back to cpu and unpack.
     sizes_cpu = sizes_cuda.cpu()
@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
                                                               data)
     # Pack on rank zero.
-    if get_model_parallel_rank() == 0:
+    if get_tensor_model_parallel_rank() == 0:
         # Check that all keys have the same data type.
         _check_data_types(keys, data, datatype)
         # Flatten the data associated with the keys
@@ -100,9 +100,9 @@ def broadcast_data(keys, data, datatype):
                                    device=torch.cuda.current_device(),
                                    dtype=datatype)
-    # Boradcast
-    torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
-                                group=get_model_parallel_group())
+    # Broadcast
+    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
+                                group=get_tensor_model_parallel_group())
     # Unpack
     output = {}
...
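Reviewer note: the hunks above only rename the group helpers, but it may help to record how broadcast_data is meant to be called after this change. A usage sketch, assuming these helpers are still re-exported from megatron.mpu; build_batch is a hypothetical stand-in for whatever produces the batch dict on rank 0 of each tensor-model-parallel group:

import torch
from megatron import mpu

# Usage sketch: rank 0 of each tensor-model-parallel group supplies the batch and
# broadcasts it; the other ranks pass None and receive rank 0's tensors back.
keys = ['text', 'types']
data = build_batch() if mpu.get_tensor_model_parallel_rank() == 0 else None
data_b = mpu.broadcast_data(keys, data, torch.int64)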
@@ -28,8 +28,9 @@ try:
 except Exception as e:
     print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
+from .initialize import is_pipeline_first_stage
 from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank
 def l2_grad_clipper(parameters, max_norm):
@@ -43,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
     parameters_with_grads = list(filter(
         lambda p: p.grad is not None, parameters))
     # Filter parameters for norm calculations.
-    mp_rank_is_zero = (get_model_parallel_rank() == 0)
+    mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
     parameters_for_norm = list(filter(
-        lambda p: p.model_parallel or mp_rank_is_zero, parameters_with_grads))
+        lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
     # Calculate L2 norm.
     norm, _ = multi_tensor_applier(
         amp_C.multi_tensor_l2norm,
@@ -71,7 +72,7 @@ def l2_grad_clipper(parameters, max_norm):
     return total_norm
-def clip_grad_norm(parameters, max_norm, norm_type=2):
+def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
     """Clips gradient norm of an iterable of parameters.
     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -90,13 +91,27 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
-    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    if parameter_names is not None:
+        filtered_parameters = []
+        assert len(parameters) == len(parameter_names), \
+            'length of parameters and parameter_names should be the same'
+        for p, n in zip(parameters, parameter_names):
+            if p.grad is not None:
+                # TODO: Bit hacky; is there a cleaner way to do this?
+                # Count embedding layer only once (in first stage).
+                # Don't count the weights a second time in the last stage.
+                if "embedding" not in n or \
+                        is_pipeline_first_stage():
+                    filtered_parameters.append(p)
+        parameters = filtered_parameters
+    else:
+        parameters = list(filter(lambda p: p.grad is not None, parameters))
     max_norm = float(max_norm)
     norm_type = float(norm_type)
     if norm_type == inf:
         total_norm = max(p.grad.data.abs().max() for p in parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        # Take max across all GPUs.
+        # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
@@ -105,16 +120,13 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
         if clip_coef < 1:
             for p in parameters:
                 p.grad.data.mul_(clip_coef)
-    #elif norm_type == 2:
-    #    total_norm = l2_grad_clipper(parameters, max_norm)
     else:
         total_norm = 0
         for p in parameters:
-            if p.model_parallel or (get_model_parallel_rank() == 0):
+            if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
                 param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
                 total_norm += param_norm.item() ** norm_type
-        # Sum across all model parallel GPUs.
+        # Sum across all model-parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.SUM,
...
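Reviewer note on the parameter_names path added to clip_grad_norm: with pipeline parallelism the tied embedding weights live on both the first and last stage, so counting them on every stage would double-count their contribution to the global norm. A minimal sketch of the rule, assuming the names come from model.named_parameters():

def should_count_for_norm(param, name, is_first_stage):
    # Skip parameters without gradients, and count any parameter whose name
    # contains "embedding" only on the first pipeline stage, so the tied
    # embedding weights enter the norm exactly once.
    if param.grad is None:
        return False
    return "embedding" not in name or is_first_stage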
@@ -21,75 +21,148 @@ import torch
 from .utils import ensure_divisibility
-# Model parallel group that the current rank belongs to.
+# Intra-layer model parallel group that the current rank belongs to.
+_TENSOR_MODEL_PARALLEL_GROUP = None
+# Inter-layer model parallel group that the current rank belongs to.
+_PIPELINE_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra- and pipeline) that the current rank belongs to.
 _MODEL_PARALLEL_GROUP = None
+# Embedding group.
+_EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None
 # These values enable us to change the mpu sizes on the fly.
-_MPU_WORLD_SIZE = None
-_MPU_RANK = None
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_TENSOR_MODEL_PARALLEL_RANK = None
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+# A list of global ranks for each pipeline group to ease calculation of the source
+# rank when broadcasting from the first or last pipeline stage
+_PIPELINE_GLOBAL_RANKS = None
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
-def initialize_model_parallel(model_parallel_size_):
+def initialize_model_parallel(tensor_model_parallel_size_=1,
+                              pipeline_model_parallel_size_=1):
     """
     Initialize model data parallel groups.
     Arguments:
-        model_parallel_size: number of GPUs used to parallelize model.
+        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
+        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.
-    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
-    use 2 GPUs to parallelize the model. The present function will
-    create 4 model parallel groups and 2 data parallel grous as:
-        4 model parallel groups:
-            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
-        2 data parallel groups:
-            [g0, g2, g4, g6], [g1, g3, g5, g7]
+    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
+    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
+    the model pipeline. The present function will
+    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
+    and 8 data-parallel groups as:
+        8 data_parallel groups:
+            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+        8 tensor model-parallel groups:
+            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+        4 pipeline model-parallel groups:
+            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
     Note that for efficiency, the caller should make sure adjacent ranks
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
     if torch.distributed.get_rank() == 0:
-        print('> initializing model parallel with size {}'.format(
-            model_parallel_size_))
+        print('> initializing tensor model parallel with size {}'.format(
+            tensor_model_parallel_size_))
+        print('> initializing pipeline model parallel with size {}'.format(
+            pipeline_model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = min(model_parallel_size_, world_size)
-    ensure_divisibility(world_size, model_parallel_size)
+    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
+    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
+    ensure_divisibility(world_size,
+                        tensor_model_parallel_size * pipeline_model_parallel_size)
+    data_parallel_size = world_size // (tensor_model_parallel_size *
+                                        pipeline_model_parallel_size)
+    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
+    num_data_parallel_groups = world_size // data_parallel_size
     rank = torch.distributed.get_rank()
-    # Build the data parallel groups.
+    # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
-    for i in range(model_parallel_size):
-        ranks = range(i, world_size, model_parallel_size)
-        group = torch.distributed.new_group(ranks)
-        if i == (rank % model_parallel_size):
-            _DATA_PARALLEL_GROUP = group
-    # Build the model parallel groups.
+    all_data_parallel_group_ranks = []
+    for i in range(pipeline_model_parallel_size):
+        start_rank = i * num_pipeline_model_parallel_groups
+        end_rank = (i + 1) * num_pipeline_model_parallel_groups
+        for j in range(tensor_model_parallel_size):
+            ranks = range(start_rank + j, end_rank,
+                          tensor_model_parallel_size)
+            all_data_parallel_group_ranks.append(list(ranks))
+            group = torch.distributed.new_group(ranks)
+            if rank in ranks:
+                _DATA_PARALLEL_GROUP = group
+    # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
     assert _MODEL_PARALLEL_GROUP is None, \
         'model parallel group is already initialized'
-    for i in range(world_size // model_parallel_size):
-        ranks = range(i * model_parallel_size,
-                      (i + 1) * model_parallel_size)
+    for i in range(data_parallel_size):
+        ranks = [data_parallel_group_ranks[i]
+                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
         group = torch.distributed.new_group(ranks)
-        if i == (rank // model_parallel_size):
+        if rank in ranks:
             _MODEL_PARALLEL_GROUP = group
+    # Build the tensor model-parallel groups.
+    global _TENSOR_MODEL_PARALLEL_GROUP
+    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
+        'tensor model parallel group is already initialized'
+    for i in range(num_tensor_model_parallel_groups):
+        ranks = range(i * tensor_model_parallel_size,
+                      (i + 1) * tensor_model_parallel_size)
+        group = torch.distributed.new_group(ranks)
+        if rank in ranks:
+            _TENSOR_MODEL_PARALLEL_GROUP = group
+    # Build the pipeline model-parallel groups and embedding groups
+    # (first and last rank in each pipeline model-parallel group).
+    global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
+    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
+        'pipeline model parallel group is already initialized'
+    global _EMBEDDING_GROUP
+    assert _EMBEDDING_GROUP is None, \
+        'embedding group is already initialized'
+    for i in range(num_pipeline_model_parallel_groups):
+        ranks = range(i, world_size,
+                      num_pipeline_model_parallel_groups)
+        group = torch.distributed.new_group(ranks)
+        if rank in ranks:
+            _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
+        # Setup embedding group (to exchange gradients between
+        # first and last stages).
+        if len(ranks) > 1:
+            embedding_ranks = [ranks[0], ranks[-1]]
+        else:
+            embedding_ranks = ranks
+        group = torch.distributed.new_group(embedding_ranks)
+        if rank in embedding_ranks:
+            _EMBEDDING_GROUP = group
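Reviewer note: the rank lists produced by the loops above can be checked offline against the docstring example (16 GPUs, tensor_model_parallel_size=2, pipeline_model_parallel_size=4). A standalone sketch that mirrors the index arithmetic without creating any torch.distributed groups:

# Standalone sketch reproducing the rank lists from the docstring example
# (no torch.distributed calls; just the index arithmetic used above).
world_size = 16
tensor_mp = 2
pipeline_mp = 4
num_pipeline_groups = world_size // pipeline_mp            # 4

data_groups = []
for i in range(pipeline_mp):
    start, end = i * num_pipeline_groups, (i + 1) * num_pipeline_groups
    for j in range(tensor_mp):
        data_groups.append(list(range(start + j, end, tensor_mp)))

tensor_groups = [list(range(i * tensor_mp, (i + 1) * tensor_mp))
                 for i in range(world_size // tensor_mp)]
pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                   for i in range(num_pipeline_groups)]

print(data_groups)      # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(tensor_groups)    # [[0, 1], [2, 3], ..., [14, 15]]
print(pipeline_groups)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]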
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
+    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
+        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
+        _DATA_PARALLEL_GROUP is None:
         return False
     return True
@@ -101,6 +174,20 @@ def get_model_parallel_group():
     return _MODEL_PARALLEL_GROUP
+def get_tensor_model_parallel_group():
+    """Get the tensor model parallel group the caller rank belongs to."""
+    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
+        'intra_layer_model parallel group is not initialized'
+    return _TENSOR_MODEL_PARALLEL_GROUP
+def get_pipeline_model_parallel_group():
+    """Get the pipeline model parallel group the caller rank belongs to."""
+    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
+        'pipeline_model parallel group is not initialized'
+    return _PIPELINE_MODEL_PARALLEL_GROUP
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
     assert _DATA_PARALLEL_GROUP is not None, \
@@ -108,41 +195,97 @@ def get_data_parallel_group():
     return _DATA_PARALLEL_GROUP
-def set_model_parallel_world_size(world_size):
-    """Set the model parallel size"""
-    global _MPU_WORLD_SIZE
-    _MPU_WORLD_SIZE = world_size
-def get_model_parallel_world_size():
-    """Return world size for the model parallel group."""
-    global _MPU_WORLD_SIZE
-    if _MPU_WORLD_SIZE is not None:
-        return _MPU_WORLD_SIZE
-    return torch.distributed.get_world_size(group=get_model_parallel_group())
-def set_model_parallel_rank(rank):
-    """Set model parallel rank."""
-    global _MPU_RANK
-    _MPU_RANK = rank
-def get_model_parallel_rank():
-    """Return my rank for the model parallel group."""
-    global _MPU_RANK
-    if _MPU_RANK is not None:
-        return _MPU_RANK
-    return torch.distributed.get_rank(group=get_model_parallel_group())
-def get_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank zeor
-    in the model parallel group."""
+def get_embedding_group():
+    """Get the embedding group the caller rank belongs to."""
+    assert _EMBEDDING_GROUP is not None, \
+        'embedding group is not initialized'
+    return _EMBEDDING_GROUP
+def set_tensor_model_parallel_world_size(world_size):
+    """Set the tensor model parallel size"""
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
+def set_pipeline_model_parallel_world_size(world_size):
+    """Set the pipeline model parallel size"""
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
+def get_tensor_model_parallel_world_size():
+    """Return world size for the tensor model parallel group."""
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
+        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
+        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
+def set_tensor_model_parallel_rank(rank):
+    """Set tensor model parallel rank."""
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
+def set_pipeline_model_parallel_rank(rank):
+    """Set pipeline model parallel rank."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
+def get_tensor_model_parallel_rank():
+    """Return my rank for the tensor model parallel group."""
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
+        return _MPU_TENSOR_MODEL_PARALLEL_RANK
+    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
+        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
+def is_pipeline_first_stage():
+    """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    return get_pipeline_model_parallel_rank() == 0
+def is_pipeline_last_stage():
+    """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    return get_pipeline_model_parallel_rank() == (
+        get_pipeline_model_parallel_world_size() - 1)
+def get_tensor_model_parallel_src_rank():
+    """Calculate the global rank corresponding to the first local rank
+    in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
-    local_world_size = get_model_parallel_world_size()
+    local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size
+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
@@ -156,7 +299,9 @@ def get_data_parallel_rank():
 def destroy_model_parallel():
     """Set the groups to none."""
-    global _MODEL_PARALLEL_GROUP
-    _MODEL_PARALLEL_GROUP = None
+    global _TENSOR_MODEL_PARALLEL_GROUP
+    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _PIPELINE_MODEL_PARALLEL_GROUP
+    _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
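Reviewer note: get_tensor_model_parallel_src_rank relies on tensor-model-parallel groups being contiguous blocks of global ranks, so the source rank is just the caller's global rank rounded down to a multiple of the group size. A small worked check of that arithmetic:

# Worked check of the src-rank arithmetic used above, for a group size of 2
# (groups [0, 1], [2, 3], ...) and a group size of 4 (groups [0..3], [4..7], ...).
def tensor_mp_src_rank(global_rank, tensor_mp_world_size):
    return (global_rank // tensor_mp_world_size) * tensor_mp_world_size

assert tensor_mp_src_rank(5, 2) == 4    # rank 5 lives in group [4, 5]
assert tensor_mp_src_rank(4, 2) == 4
assert tensor_mp_src_rank(7, 4) == 4    # with size 4, rank 7 lives in group [4..7]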
@@ -35,12 +35,12 @@ except Exception as e:
         'instead of apex.normalization.FusedLayerNorm!')
     from torch.nn import LayerNorm
-from .initialize import get_model_parallel_rank
-from .initialize import get_model_parallel_world_size
-from .mappings import copy_to_model_parallel_region
-from .mappings import gather_from_model_parallel_region
-from .mappings import reduce_from_model_parallel_region
-from .mappings import scatter_to_model_parallel_region
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
     """Initialize affine weight for model parallel on GPU."""
-    weight.model_parallel = True
+    weight.tensor_model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride
@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     Build the master weight on all processes and scatter
     the relevant chunk."""
-    weight.model_parallel = True
+    weight.tensor_model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride
@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     weight_list = torch.split(master_weight, per_partition_per_stride_size,
                               dim=partition_dim)
     rank = get_model_parallel_rank()
-    world_size = get_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     my_weight_list = weight_list[rank::world_size]
     with torch.no_grad():
@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
-        self.model_parallel_size = get_model_parallel_world_size()
+        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
-                self.num_embeddings, get_model_parallel_rank(),
-                self.model_parallel_size)
+                self.num_embeddings, get_tensor_model_parallel_rank(),
+                self.tensor_model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                                       partition_dim=0, stride=1)
     def forward(self, input_):
-        if self.model_parallel_size > 1:
+        if self.tensor_model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
                          (input_ >= self.vocab_end_index)
@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
         # Mask the output embedding.
-        if self.model_parallel_size > 1:
+        if self.tensor_model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_from_model_parallel_region(output_parallel)
+        output = reduce_from_tensor_model_parallel_region(output_parallel)
         return output
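Reviewer note on VocabParallelEmbedding.forward: each tensor-model-parallel rank owns a contiguous slice of the vocabulary, looks up only its own rows, zeroes the rows it does not own, and relies on the all-reduce to sum the partial results. A minimal sketch of that masking step (group handling and weight layout assumed as above):

import torch

def masked_local_lookup(input_ids, weight, vocab_start_index, vocab_end_index):
    # Tokens outside [vocab_start_index, vocab_end_index) belong to other ranks.
    mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    local_ids = input_ids.clone() - vocab_start_index
    local_ids[mask] = 0                      # any in-range row; result is zeroed below
    out = torch.nn.functional.embedding(local_ids, weight)
    out[mask, :] = 0.0                       # rows owned by other ranks contribute zero
    return out                               # an all-reduce over the group sums the partials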
@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_model_parallel_world_size()
+        world_size = get_tensor_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
         self.skip_bias_add = skip_bias_add
@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
                     self.output_size_per_partition,
                     device=torch.cuda.current_device(),
                     dtype=args.params_dtype))
-                self.bias.model_parallel = True
+                self.bias.tensor_model_parallel = True
                 self.bias.partition_dim = 0
                 self.bias.stride = stride
                 # Always initialize bias to zero.
@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
     def forward(self, input_):
         # Set up backprop all-reduce.
-        input_parallel = copy_to_model_parallel_region(input_)
+        input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
         output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_model_parallel_region(output_parallel)
+            output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
         # Divide the weight matrix along the last dimension.
-        world_size = get_model_parallel_world_size()
+        world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_model_parallel_region(input_)
+            input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_model_parallel_region(output_parallel)
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
...
@@ -15,7 +15,7 @@
 import torch
-from .initialize import get_model_parallel_group, get_model_parallel_world_size, get_model_parallel_rank
+from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
 from .utils import split_tensor_along_last_dim
@@ -23,11 +23,11 @@ def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""
     # Bypass the function if we are using only 1 GPU.
-    if get_model_parallel_world_size()==1:
+    if get_tensor_model_parallel_world_size()==1:
         return input_
     # All-reduce.
-    torch.distributed.all_reduce(input_, group=get_model_parallel_group())
+    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
     return input_
@@ -36,7 +36,7 @@ def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""
-    world_size = get_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size==1:
         return input_
@@ -45,7 +45,7 @@ def _split(input_):
     input_list = split_tensor_along_last_dim(input_, world_size)
     # Note: torch.split does not create contiguous tensors by default.
-    rank = get_model_parallel_rank()
+    rank = get_tensor_model_parallel_rank()
     output = input_list[rank].contiguous()
     return output
@@ -54,18 +54,18 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""
-    world_size = get_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size==1:
         return input_
     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = get_model_parallel_rank()
+    rank = get_tensor_model_parallel_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group())
+    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()
@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 # Helper functions.
 # -----------------
-def copy_to_model_parallel_region(input_):
+def copy_to_tensor_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)
-def reduce_from_model_parallel_region(input_):
+def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)
-def scatter_to_model_parallel_region(input_):
+def scatter_to_tensor_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)
-def gather_from_model_parallel_region(input_):
+def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
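Reviewer note: the renamed helpers keep the usual conjugate forward/backward pairing — copy is an identity on activations and an all-reduce on gradients, while reduce is the opposite. A minimal sketch of that pattern (process-group argument omitted for brevity; not the actual classes in this file):

import torch

class _CopySketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return input_                                # identity on activations
    @staticmethod
    def backward(ctx, grad_output):
        torch.distributed.all_reduce(grad_output)    # sum grads over the tensor-MP group
        return grad_output

class _ReduceSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        torch.distributed.all_reduce(input_)         # sum activations over the group
        return input_
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                           # identity on gradients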
@@ -28,9 +28,9 @@ from megatron import get_args
 from megatron.memory import allocate_mem_buff
 from .initialize import get_data_parallel_rank
-from .initialize import get_model_parallel_group
-from .initialize import get_model_parallel_rank
-from .initialize import get_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
 # Default name for the model parallel rng tracker.
@@ -45,8 +45,8 @@ def init_checkpointed_activations_memory_buffer():
     """Initializ the memory buffer for the checkpointed activations."""
     args = get_args()
-    per_layer = args.batch_size * args.max_position_embeddings * \
-                args.hidden_size // args.model_parallel_size
+    per_layer = args.micro_batch_size * args.max_position_embeddings * \
+                args.hidden_size // args.tensor_model_parallel_size
     assert args.num_layers % args.checkpoint_num_layers == 0, \
         'number of layers is not divisible by checkpoint-num-layers'
     num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
@@ -54,7 +54,7 @@ def init_checkpointed_activations_memory_buffer():
     dtype = torch.half
     if not args.fp16:
         dtype = torch.float
     global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
     assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
         'checkpointed activations memory buffer is already allocated.'
@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
 def split_tensor_into_1d_equal_chunks(tensor):
     """Break a tensor into equal 1D chunks."""
     data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_model_parallel_world_size()
-    start_index = partition_size * get_model_parallel_rank()
+    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
     return data[start_index:end_index]
 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
-    world_size = get_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     numel = torch.numel(tensor)
     numel_gathered = world_size * numel
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
                            requires_grad=False)
     chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
     torch.distributed.all_gather(chunks, tensor,
-                                 group=get_model_parallel_group())
+                                 group=get_tensor_model_parallel_group())
     return gathered
@@ -215,15 +215,15 @@ def model_parallel_cuda_manual_seed(seed):
     default state: This is for data parallelism and is the same among a
                    set of model parallel GPUs but different across
                    different model paralle groups. This is used for
-                   example for dropout in the non-model-parallel regions.
-    model-parallel state: This state is different among a set of model
+                   example for dropout in the non-tensor-model-parallel regions.
+    tensor-model-parallel state: This state is different among a set of model
                           parallel GPUs, but the same across data parallel
                           groups. This is used for example for dropout in
                           model parallel regions.
     """
     # 2718 is just for fun and any POSITIVE value will work.
     offset = seed + 2718
-    model_parallel_seed = offset + get_model_parallel_rank()
+    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
     # Data parallel gets the original seed.
     data_parallel_seed = seed
@@ -231,15 +231,15 @@ def model_parallel_cuda_manual_seed(seed):
         print('> initializing model parallel cuda seeds on global rank {}, '
               'model parallel rank {}, and data parallel rank {} with '
               'model parallel seed: {} and data parallel seed: {}'.format(
-                  torch.distributed.get_rank(), get_model_parallel_rank(),
-                  get_data_parallel_rank(), model_parallel_seed,
+                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
+                  get_data_parallel_rank(), tensor_model_parallel_seed,
                   data_parallel_seed), flush=True)
     _CUDA_RNG_STATE_TRACKER.reset()
     # Set the default state.
     torch.cuda.manual_seed(data_parallel_seed)
     # and model parallel state.
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
-                                model_parallel_seed)
+                                tensor_model_parallel_seed)
 class CheckpointFunction(torch.autograd.Function):
@@ -268,11 +268,11 @@ class CheckpointFunction(torch.autograd.Function):
             args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
             args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
                 args[0].data)
         # Store everything.
         ctx.save_for_backward(*args)
         return outputs
     @staticmethod
...
@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
     identity = IdentityLayer((batch_size, seq_length, vocab_size),
                              scale=logits_scale).cuda()
     logits = identity()
-    logits_parallel = mpu.scatter_to_model_parallel_region(logits)
+    logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
     target = torch.cuda.LongTensor(
         size=(batch_size, seq_length)).random_(0, vocab_size)
     loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
     return loss, identity.weight.grad
-def test_cross_entropy(model_parallel_size):
+def test_cross_entropy(tensor_model_parallel_size):
     if torch.distributed.get_rank() == 0:
         print('> testing cross entropy with model parallel size {} ...'.
-              format(model_parallel_size))
-    mpu.initialize_model_parallel(model_parallel_size)
-    model_parallel_size = mpu.get_model_parallel_world_size()
+              format(tensor_model_parallel_size))
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
     batch_size = 13
     seq_length = 17
     vocab_size_per_partition = 11
     logits_scale = 1000.0
-    vocab_size = vocab_size_per_partition * model_parallel_size
+    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
     seed = 1234
     loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
@@ -89,7 +89,7 @@ def test_cross_entropy(model_parallel_size):
     assert error < 1.0e-6
     # Reset groups
-    mpu.destroy_model_parallel()
+    mpu.destroy_tensor_model_parallel()
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
@@ -101,8 +101,8 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = 1
-    while model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test cross entropy')
-        test_cross_entropy(model_parallel_size)
-        model_parallel_size *= 2
+        test_cross_entropy(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
@@ -24,15 +24,15 @@ import sys
 sys.path.append("../..")
-def test_boradcast_data(model_parallel_size):
+def test_broadcast_data(tensor_model_parallel_size):
     if torch.distributed.get_rank() == 0:
-        print('> testing boradcast_data with model parallel size {} ...'.
-              format(model_parallel_size))
-    mpu.initialize_model_parallel(model_parallel_size)
+        print('> testing broadcast_data with model parallel size {} ...'.
+              format(tensor_model_parallel_size))
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     torch.manual_seed(1234 + mpu.get_data_parallel_rank())
-    model_parallel_size = mpu.get_model_parallel_world_size()
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
     key_size_t = {'key1': [7, 11],
                   'key2': [8, 2, 1],
@@ -48,7 +48,7 @@ def test_boradcast_data(model_parallel_size):
         data_t[key] = data[key].clone()
     data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
     data_t['keyX'] = data['keyX'].clone()
-    if mpu.get_model_parallel_rank() != 0:
+    if mpu.get_tensor_model_parallel_rank() != 0:
         data = None
     data_utils._check_data_types(keys, data_t, torch.int64)
@@ -69,7 +69,7 @@ def test_boradcast_data(model_parallel_size):
         assert data_b[key].sub(tensor).abs().max() == 0
     # Reset groups
-    mpu.destroy_model_parallel()
+    mpu.destroy_tensor_model_parallel()
     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:
@@ -81,8 +81,8 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = 1
-    while model_parallel_size <= world_size:
-        print_separator('test test boradcast data')
-        test_boradcast_data(model_parallel_size)
-        model_parallel_size *= 2
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
+        print_separator('test test broadcast data')
+        test_broadcast_data(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
@@ -21,15 +21,15 @@ import sys
 sys.path.append("../..")
-def test_initialize_model_parallel(model_parallel_size):
+def test_initialize_model_parallel(tensor_model_parallel_size):
     if torch.distributed.get_rank() == 0:
         print('> testing initialize_model_parallel with size {} ...'.format(
-            model_parallel_size))
-    model_parallel_size_ = min(model_parallel_size,
+            tensor_model_parallel_size))
+    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                torch.distributed.get_world_size())
     assert not mpu.model_parallel_is_initialized()
-    mpu.initialize_model_parallel(model_parallel_size_)
+    mpu.initialize_model_parallel(tensor_model_parallel_size_)
     assert mpu.model_parallel_is_initialized()
     # Checks.
@@ -38,15 +38,15 @@ def test_initialize_model_parallel(model_parallel_size):
         assert rank == torch.distributed.get_rank(group=group)
     # Model parallel.
-    world_size = model_parallel_size_
-    rank = torch.distributed.get_rank() % model_parallel_size_
-    assert world_size == mpu.get_model_parallel_world_size()
-    assert rank == mpu.get_model_parallel_rank()
-    check(mpu.get_model_parallel_group(), world_size, rank)
+    world_size = tensor_model_parallel_size_
+    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
+    assert world_size == mpu.get_tensor_model_parallel_world_size()
+    assert rank == mpu.get_tensor_model_parallel_rank()
+    check(mpu.get_tensor_model_parallel_group(), world_size, rank)
     # Data parallel.
-    world_size = torch.distributed.get_world_size() // model_parallel_size_
-    rank = torch.distributed.get_rank() // model_parallel_size
+    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
+    rank = torch.distributed.get_rank() // tensor_model_parallel_size
     assert world_size == mpu.get_data_parallel_world_size()
     assert rank == mpu.get_data_parallel_rank()
     check(mpu.get_data_parallel_group(), world_size, rank)
@@ -59,20 +59,20 @@ def test_initialize_model_parallel(model_parallel_size):
         print('>> passed the test :-)')
-def test_get_model_parallel_src_rank(model_parallel_size_):
+def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
     if torch.distributed.get_rank() == 0:
-        print('> testing get_model_parallel_src_rank with size {} ...'.format(
-            model_parallel_size_))
-    model_parallel_size = min(model_parallel_size_,
+        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
+            tensor_model_parallel_size_))
+    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                               torch.distributed.get_world_size())
     assert not mpu.model_parallel_is_initialized()
-    mpu.initialize_model_parallel(model_parallel_size)
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     assert mpu.model_parallel_is_initialized()
     # Checks
-    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
-    assert mpu.get_model_parallel_src_rank() == src_rank
+    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
+    assert mpu.get_tensor_model_parallel_src_rank() == src_rank
     # Reset groups
     mpu.destroy_model_parallel()
@@ -86,10 +86,10 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = 1
-    while model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test initialize model parallel')
-        test_initialize_model_parallel(model_parallel_size)
+        test_initialize_model_parallel(tensor_model_parallel_size)
         print_separator('test model parallel source rank')
-        test_get_model_parallel_src_rank(model_parallel_size)
-        model_parallel_size *= 2
+        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
...@@ -26,14 +26,14 @@ import sys ...@@ -26,14 +26,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_parallel_embedding(model_parallel_size): def test_parallel_embedding(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'. print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
batch_size = 17 batch_size = 17
seq_length = 23 seq_length = 23
...@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size): ...@@ -80,16 +80,16 @@ def test_parallel_embedding(model_parallel_size):
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size, hidden_size // tensor_model_parallel_size,
1)[mpu.get_model_parallel_rank()] 1)[mpu.get_tensor_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format( print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error)) torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error) assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad, weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size, vocab_size // tensor_model_parallel_size,
0)[mpu.get_model_parallel_rank()] 0)[mpu.get_tensor_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub( error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max() weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format( print(' error in grad (vocab parallel) on global rank {}: {}'.format(
...@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size): ...@@ -104,19 +104,19 @@ def test_parallel_embedding(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size): def test_initialize_affine_weight(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel ' print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
# --------------- # ---------------
# Column parallel # Column parallel
...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size): ...@@ -131,7 +131,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff, my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size): ...@@ -154,7 +154,7 @@ def test_initialize_affine_weight(model_parallel_size):
set_random_seed(seed) set_random_seed(seed)
master_weight = torch.empty(output_size, input_size) master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight) torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff, my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
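test_initialize_affine_weight relies on every rank regenerating the same master weight from a fixed seed and then keeping only its torch.split slice: dim 0 for the column-parallel case (output features), dim 1 for the row-parallel case (input features). A small sketch of that slicing, with sizes and rank chosen arbitrarily:

import torch

def shard_master_weight(master_weight, world_size, rank, dim):
    """Return the contiguous slice of a master weight owned by `rank`.

    dim=0 mimics the column-parallel split (output features),
    dim=1 mimics the row-parallel split (input features).
    """
    per_partition = master_weight.size(dim) // world_size
    return torch.split(master_weight, per_partition, dim=dim)[rank].contiguous().clone()

torch.manual_seed(12345)
master = torch.empty(17 * 4, 13 * 4)
torch.nn.init.normal_(master)
col_shard = shard_master_weight(master, world_size=4, rank=2, dim=0)  # shape (17, 52)
row_shard = shard_master_weight(master, world_size=4, rank=2, dim=1)  # shape (68, 13)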
...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module): ...@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight return self.weight
def test_column_parallel_linear(model_parallel_size): def test_column_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel ' print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size): ...@@ -219,7 +219,7 @@ def test_column_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff, my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone() dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
...@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size): ...@@ -250,20 +250,20 @@ def test_column_parallel_linear(model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
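The column-parallel test derives its reference gradient analytically: with Y = X Aᵀ and an upstream gradient dL/dY, the weight gradient is dL/dA = (dL/dY)ᵀ X, and each rank should hold the block of rows matching its output shard. A runnable single-process check of that identity (shapes echo the 13/17 coefficients above but are otherwise illustrative):

import torch

torch.manual_seed(12345)
batch, in_f, out_f, world_size = 7, 52, 68, 4
X = torch.randn(batch, in_f)
A = torch.randn(out_f, in_f, requires_grad=True)
dLdY = torch.randn(batch, out_f)

Y = X @ A.t()
Y.backward(dLdY)

# Analytic weight gradient, and the rows a column-parallel rank would own.
dLdA = dLdY.t() @ X
rank, out_per_rank = 2, out_f // world_size
my_dLdA = torch.split(dLdA, out_per_rank, dim=0)[rank]
assert torch.allclose(my_dLdA, torch.split(A.grad, out_per_rank, dim=0)[rank], atol=1e-5)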
def test_row_parallel_linear(model_parallel_size): def test_row_parallel_linear(tensor_model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel ' print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
input_size_coeff = 13 input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size input_size = input_size_coeff * tensor_model_parallel_size
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
# Network # Network
...@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size): ...@@ -286,7 +286,7 @@ def test_row_parallel_linear(model_parallel_size):
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A) dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff, my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone() dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max() error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
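For the row-parallel layer the weight is split along the input dimension instead, so each rank produces a partial product that must be summed (an all-reduce in the real layer) to recover the full output. A single-process sketch of that partial-sum identity:

import torch

torch.manual_seed(12345)
batch, in_f, out_f, world_size = 7, 52, 68, 4
X = torch.randn(batch, in_f)
A = torch.randn(out_f, in_f)

# Full output versus the sum of per-rank partial products over input shards.
Y_full = X @ A.t()
in_per_rank = in_f // world_size
partials = [
    torch.split(X, in_per_rank, dim=1)[r] @ torch.split(A, in_per_rank, dim=1)[r].t()
    for r in range(world_size)
]
assert torch.allclose(Y_full, sum(partials), atol=1e-5)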
...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module): ...@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return self.weight return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length): sequence_length):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, ...@@ -352,17 +352,17 @@ def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size): def test_parallel_self_attention(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel ' print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
...@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -370,14 +370,14 @@ def test_parallel_self_attention(model_parallel_size):
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \ rank_1, hideen_size_1, tensor_model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 = parallel_self_attention( attention_layer_1, identity_layer_1 = parallel_self_attention(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
attention_layer, identity_layer = parallel_self_attention( attention_layer, identity_layer = parallel_self_attention(
model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size assert hideen_size_1 == hidden_size
...@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -389,7 +389,7 @@ def test_parallel_self_attention(model_parallel_size):
my_lin_grad_list = torch.split( my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad, attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size] hidden_size // tensor_model_parallel_size, 0)[rank::tensor_model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0) my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub( error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max() attention_layer.query_key_value.weight.grad).abs().max()
...@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size): ...@@ -410,11 +410,11 @@ def test_parallel_self_attention(model_parallel_size):
print(' >> passed the test :-)') print(' >> passed the test :-)')
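The gradient check on query_key_value.weight.grad uses a strided slice, [rank::tensor_model_parallel_size], because the fused QKV weight stacks the Q, K and V blocks along dim 0 and each rank owns one chunk of every block. A small sketch of how that slice picks out a rank's rows (the stacked layout is inferred from the test's indexing; sizes are illustrative):

import torch

# The fused weight stacks [Q, K, V] along dim 0; each block has `hidden` rows.
hidden, tp = 12, 4
qkv_grad = torch.arange(3 * hidden).float().unsqueeze(1).expand(3 * hidden, hidden)

rank = 1
chunks = torch.split(qkv_grad, hidden // tp, dim=0)   # tp chunks of Q, then of K, then of V
my_chunks = chunks[rank::tp]                          # this rank's chunk of Q, K and V
my_grad = torch.cat(my_chunks, dim=0)                 # shape (3 * hidden // tp, hidden)
assert my_grad.shape == (3 * hidden // tp, hidden)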
def parallel_transformer(model_parallel_size, num_att_heads_per_partition, def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length): hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed = 12345 seed = 12345
set_random_seed(seed) set_random_seed(seed)
...@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition, ...@@ -440,31 +440,31 @@ def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
# Backward # Backward
loss.backward() loss.backward()
rank = mpu.get_model_parallel_rank() rank = mpu.get_tensor_model_parallel_rank()
mpu.destroy_model_parallel() mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \ return rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size): def test_parallel_transformer_layer(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel ' print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size)) 'size: {}'.format(tensor_model_parallel_size))
num_att_heads_per_partition = 3 num_att_heads_per_partition = 3
hidden_size_per_att_head = 7 hidden_size_per_att_head = 7
batch_size = 5 batch_size = 5
sequence_length = 13 sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer( transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition, 1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \ rank, hidden_size, tensor_model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer( transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition, tensor_model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length) hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max() error = loss_1.sub(loss).abs().max()
...@@ -494,37 +494,37 @@ if __name__ == '__main__': ...@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight') print_separator('test initialize affine weight')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size) test_initialize_affine_weight(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test parallel embedding') print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size) test_parallel_embedding(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test column-parallel linear') print_separator('test column-parallel linear')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size) test_column_parallel_linear(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test row-parallel linear') print_separator('test row-parallel linear')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size) test_row_parallel_linear(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel self-attention') print_separator('test parallel self-attention')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size) test_parallel_self_attention(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
print_separator('test parallel transformer') print_separator('test parallel transformer')
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size) test_parallel_transformer_layer(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -21,14 +21,14 @@ import sys ...@@ -21,14 +21,14 @@ import sys
sys.path.append("../..") sys.path.append("../..")
def test_set_cuda_rng_state(model_parallel_size): def test_set_cuda_rng_state(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'. print('> testing set_rng_state with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
size = 123 size = 123
seed = 1234 seed = 1234
...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size): ...@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size): def test_cuda_rng_tracker(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'. print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
seed_1 = 1234 seed_1 = 1234
seed_2 = 4321 seed_2 = 4321
...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size): ...@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(model_parallel_size):
print('>> passed the test :-)') print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size): def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'. print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size)) format(tensor_model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(tensor_model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size() tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345) mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345 assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 + assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank()) mpu.get_tensor_model_parallel_rank())
# Reset the tracker # Reset the tracker
mpu.get_cuda_rng_tracker().reset() mpu.get_cuda_rng_tracker().reset()
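test_model_parallel_cuda_manual_seed asserts that inside the tracker's fork() each tensor-parallel rank sees its own seed (12345 + 2718 + rank) while the default CUDA stream keeps the global seed, so dropout patterns differ across ranks while replicated initializations agree. A toy CPU-only tracker illustrating the fork/restore idea (a sketch of the concept, not the mpu implementation):

import torch

class RngTracker:
    """Toy tracker: keep a named RNG state and swap it in inside fork()."""

    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        outside = torch.get_rng_state()
        torch.manual_seed(seed)
        self.states[name] = torch.get_rng_state()
        torch.set_rng_state(outside)

    def fork(self, name):
        tracker = self
        class _Fork:
            def __enter__(self):
                self.outside = torch.get_rng_state()
                torch.set_rng_state(tracker.states[name])
            def __exit__(self, *exc):
                tracker.states[name] = torch.get_rng_state()
                torch.set_rng_state(self.outside)
        return _Fork()

tracker = RngTracker()
rank, global_seed = 1, 12345
tracker.add('model-parallel-rng', global_seed + 2718 + rank)  # per-rank offset, as in the test
torch.manual_seed(global_seed)
with tracker.fork('model-parallel-rng'):
    a = torch.randn(3)        # drawn from the per-rank stream
b = torch.randn(3)            # drawn from the default, globally-seeded stream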
...@@ -185,20 +185,20 @@ if __name__ == '__main__': ...@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed() initialize_distributed()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test set rng state') print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size) test_set_cuda_rng_state(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test cuda rng tracker') print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size) test_cuda_rng_tracker(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
model_parallel_size = 1 tensor_model_parallel_size = 1
while model_parallel_size <= world_size: while tensor_model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed') print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size) test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
model_parallel_size *= 2 tensor_model_parallel_size *= 2
...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args): ...@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after = orig_vocab_size after = orig_vocab_size
multiple = args.make_vocab_size_divisible_by * \ multiple = args.make_vocab_size_divisible_by * \
args.model_parallel_size args.tensor_model_parallel_size
while (after % multiple) != 0: while (after % multiple) != 0:
after += 1 after += 1
if args.rank == 0: if args.rank == 0:
......
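_vocab_size_with_padding now rounds the vocabulary up to a multiple of make_vocab_size_divisible_by * tensor_model_parallel_size, so the padded embedding splits evenly across tensor-parallel ranks. A worked example (the 50257/128/8 values are illustrative, not taken from the commit):

def vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by,
                            tensor_model_parallel_size):
    """Round the vocab size up to the next multiple that shards evenly."""
    multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
    after = orig_vocab_size
    while after % multiple != 0:
        after += 1
    return after

# e.g. GPT-2's 50257 tokens, divisible-by 128, tensor-parallel size 8:
# multiple = 1024, so the vocab is padded from 50257 up to 51200.
assert vocab_size_with_padding(50257, 128, 8) == 51200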
...@@ -27,14 +27,16 @@ from megatron.checkpointing import save_checkpoint ...@@ -27,14 +27,16 @@ from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
def reduce_losses(losses): def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs.""" """Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat( averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses]) [loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses) torch.distributed.all_reduce(averaged_losses,
reduced_losses = reduced_losses / torch.distributed.get_world_size() group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return reduced_losses return averaged_losses
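The renamed helper also changes the reduction scope: losses are now all-reduced and averaged over the data-parallel group only, instead of over the whole world, which matters once tensor- and pipeline-parallel ranks exist as well. A sketch of the new behavior with the group passed in explicitly (the real helper hard-codes mpu.get_data_parallel_group() and assumes torch.distributed is already initialized):

import torch
import torch.distributed as dist

def average_losses_across_group(losses, group=None):
    """All-reduce a list of scalar losses and average them over `group`."""
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    dist.all_reduce(averaged, group=group)           # sum across ranks in the group
    averaged /= dist.get_world_size(group=group)     # turn the sum into a mean
    return averaged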
def report_memory(name): def report_memory(name):
...@@ -48,14 +50,15 @@ def report_memory(name): ...@@ -48,14 +50,15 @@ def report_memory(name):
string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes) string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
string += ' | max reserved: {}'.format( string += ' | max reserved: {}'.format(
torch.cuda.max_memory_reserved() / mega_bytes) torch.cuda.max_memory_reserved() / mega_bytes)
print_rank_0(string) if mpu.get_data_parallel_rank() == 0:
print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)
def print_params_min_max_norm(optimizer, iteration): def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters.""" """Print min, max, and norm of all parameters."""
index = 0 index = 0
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n' string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
optimizer_ = optimizer optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer): if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer optimizer_ = optimizer.optimizer
...@@ -66,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -66,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max() max_ = param.data.max()
norm = torch.linalg.norm(param.data) norm = torch.linalg.norm(param.data)
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel)) iteration, rank, index, int(param.tensor_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True) print(string, flush=True)
...@@ -96,11 +99,11 @@ def get_ltor_masks_and_position_ids(data, ...@@ -96,11 +99,11 @@ def get_ltor_masks_and_position_ids(data,
"""Build masks and position id for left to right model.""" """Build masks and position id for left to right model."""
# Extract batch size and sequence length. # Extract batch size and sequence length.
batch_size, seq_length = data.size() micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular). # Attention mask (lower triangular).
if reset_attention_mask: if reset_attention_mask:
att_mask_batch = batch_size att_mask_batch = micro_batch_size
else: else:
att_mask_batch = 1 att_mask_batch = 1
attention_mask = torch.tril(torch.ones( attention_mask = torch.tril(torch.ones(
...@@ -122,7 +125,7 @@ def get_ltor_masks_and_position_ids(data, ...@@ -122,7 +125,7 @@ def get_ltor_masks_and_position_ids(data,
if reset_position_ids or reset_attention_mask: if reset_position_ids or reset_attention_mask:
# Loop through the batches: # Loop through the batches:
for b in range(batch_size): for b in range(micro_batch_size):
# Find indecies where EOD token is. # Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token] eod_index = position_ids[b, data[b] == eod_token]
......
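get_ltor_masks_and_position_ids builds a lower-triangular causal mask and per-token position ids, optionally resetting both at EOD boundaries; the only change in this hunk is renaming batch_size to micro_batch_size. A trimmed sketch of the basic construction with the EOD handling omitted (the final < 0.5 boolean flip follows the convention used elsewhere in Megatron, where True marks positions that must not be attended):

import torch

def ltor_mask_and_positions(data):
    """Causal (left-to-right) attention mask and position ids for a batch."""
    micro_batch_size, seq_length = data.size()
    attention_mask = torch.tril(
        torch.ones(1, seq_length, seq_length, device=data.device)
    ).view(1, 1, seq_length, seq_length)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # Convert to boolean; True marks future positions that must be masked out.
    attention_mask = attention_mask < 0.5
    return attention_mask, position_ids

tokens = torch.randint(0, 100, (4, 16))
mask, pos = ltor_mask_and_positions(tokens)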
...@@ -23,9 +23,9 @@ from megatron import print_rank_0 ...@@ -23,9 +23,9 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider():
...@@ -33,10 +33,25 @@ def model_provider(): ...@@ -33,10 +33,25 @@ def model_provider():
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
model = BertModel( args = get_args()
num_tokentypes=2, if mpu.get_pipeline_model_parallel_world_size() > 1:
add_binary_head=True, # Determine model based on position of stage in pipeline.
parallel_output=True) if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
return model return model
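With pipeline parallelism enabled, model_provider builds only the slice of BERT that this rank's stage needs: the first stage carries the embedding, the last stage carries the binary head and output layer, and intermediate stages carry only transformer layers. A condensed restatement of that dispatch as a standalone function (the class and mpu names are the ones imported in the diff; they are passed in so the sketch stays self-contained):

def build_bert_for_this_stage(mpu, BertModel, BertModelFirstStage,
                              BertModelIntermediateStage, BertModelLastStage):
    """Pick the model class matching this rank's pipeline stage."""
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage():
            return BertModelFirstStage(num_tokentypes=2)
        if mpu.is_pipeline_last_stage():
            return BertModelLastStage(num_tokentypes=2,
                                      add_binary_head=True,
                                      parallel_output=True)
        return BertModelIntermediateStage(num_tokentypes=2)
    # No pipeline parallelism: every rank builds the full model.
    return BertModel(num_tokentypes=2, add_binary_head=True, parallel_output=True)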
...@@ -66,34 +81,51 @@ def get_batch(data_iterator): ...@@ -66,34 +81,51 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
# Forward model. lm_labels if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = model(tokens, padding_mask, lm_loss_, sop_logits = output_tensor
tokentype_ids=types,
lm_labels=lm_labels)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1), sentence_order.view(-1),
ignore_index=-1) ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss = torch.sum( lm_loss_ = lm_loss_.float()
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss]) averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]} return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
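On the last pipeline stage the BERT forward step turns the model output into the training loss: the masked-LM loss is averaged over the loss mask, the sentence-order (SOP) logits go through a cross-entropy with ignore_index=-1, and the two are summed; the data-parallel averaging is only for logging. A sketch of just that loss arithmetic (the shapes in the usage lines are illustrative):

import torch
import torch.nn.functional as F

def bert_last_stage_loss(lm_loss_, sop_logits, loss_mask, sentence_order):
    """Combine masked-LM and sentence-order losses as on the last stage."""
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1)
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(lm_loss_.float().view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    return lm_loss + sop_loss

# Illustrative shapes: per-token LM losses (b, s), SOP logits (b, 2).
b, s = 4, 8
loss = bert_last_stage_loss(torch.rand(b, s), torch.randn(b, 2),
                            torch.randint(0, 2, (b, s)).float(),
                            torch.randint(0, 2, (b,)))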
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -23,16 +23,28 @@ from megatron import get_timers ...@@ -23,16 +23,28 @@ from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.data.gpt2_dataset import build_train_valid_test_datasets from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
print_rank_0('building GPT2 model ...') print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True) args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPT2ModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model return model
...@@ -69,25 +81,42 @@ def get_batch(data_iterator): ...@@ -69,25 +81,42 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch( tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator) data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels) # Forward pass through the model.
loss_mask = loss_mask.view(-1) if mpu.is_pipeline_first_stage():
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging. # Reduce loss for logging.
reduced_loss = reduce_losses([loss]) averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]} return loss, {'lm loss': averaged_loss[0]}
return output_tensor
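The new forward_step signature adds input_tensor to support pipelining: the first stage reads real data from the iterator and must receive input_tensor=None, later stages consume the activation sent by the previous stage, and only the last stage returns a loss; every other stage returns its output activation to be shipped downstream. A minimal sketch of that contract (the boolean stage flags stand in for mpu.is_pipeline_first_stage / is_pipeline_last_stage, and mean() stands in for the real masked loss):

import torch

def pipelined_forward_step(model, batch, input_tensor, is_first_stage, is_last_stage):
    """Route either real inputs or the upstream activation into this stage."""
    if is_first_stage:
        assert input_tensor is None
        output = model(batch)              # consume tokens / embeddings
    else:
        assert input_tensor is not None
        output = model(input_tensor)       # consume the previous stage's activation

    if is_last_stage:
        return output.float().mean()       # only the last stage yields a loss
    return output                          # everyone else forwards an activation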
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -25,12 +25,13 @@ from megatron import get_timers ...@@ -25,12 +25,13 @@ from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import reduce_losses from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider(): def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False) return general_ict_model_provider(False, False)
...@@ -72,22 +73,22 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function): ...@@ -72,22 +73,22 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
return output return output
def forward_step(data_iterator, model): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch-generator').start()
query_tokens, query_pad_mask, \ query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator) block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
timers('batch generator').stop() timers('batch-generator').stop()
# Forward model. # Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask) query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
local_batch_size = query_logits.shape[0] micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1 global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits) all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits) all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
...@@ -102,11 +103,12 @@ def forward_step(data_iterator, model): ...@@ -102,11 +103,12 @@ def forward_step(data_iterator, model):
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies] topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda()) retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *topk_accs]) retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies # create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])} topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict) stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
return retrieval_loss, stats_dict return retrieval_loss, stats_dict
......
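The ICT forward step gathers query and block logits across the data-parallel group, scores every query against every block, and uses the diagonal (each query's own block) as the cross-entropy target. A single-process sketch of that scoring and loss with the all-gather elided (the dot-product scoring is inferred from the elided retrieval_scores computation, so treat it as an assumption):

import torch

def retrieval_loss(all_query_logits, all_block_logits):
    """In-batch retrieval: row i's positive block is column i."""
    scores = torch.matmul(all_query_logits, all_block_logits.t())   # (B, B) similarity matrix
    targets = torch.arange(scores.size(0), device=scores.device)
    return torch.nn.CrossEntropyLoss()(scores.float(), targets)

B, d = 16, 128
loss = retrieval_loss(torch.randn(B, d), torch.randn(B, d))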