Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Use to partition the activations stored for backward propagation
Therefore reduces the memory consumption
Also implements CPU checkpointing and contiguous memory checkpointing
...@@ -10,7 +10,7 @@ Reduces memory consumption and memory fragmentation
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
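Note (illustration, not part of this merge): the module above is normally driven through the deepspeed.checkpointing wrapper, which routes into the CheckpointFunction further down in this file. A minimal usage sketch, where `layer`, `hidden_states`, and `attention_mask` are assumed user objects and the configure() keyword names are assumptions based on the flags this module reads:

import deepspeed

# Optional global configuration of the behaviours described in the docstring above.
deepspeed.checkpointing.configure(None, partition_activations=True, checkpoint_in_cpu=True)

# Re-runs layer.forward during backward instead of storing its activations.
hidden = deepspeed.checkpointing.checkpoint(layer, hidden_states, attention_mask)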
...@@ -82,9 +82,7 @@ def detach_variable(inputs, device=None):
out.append(x)
return tuple(out)
else:
raise RuntimeError("Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
def _set_cuda_rng_state(new_state, device=-1):
...@@ -92,7 +90,7 @@ def _set_cuda_rng_state(new_state, device=-1):
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
...@@ -128,6 +126,7 @@ class CudaRNGStatesTracker:
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
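As background, this tracker follows the Megatron-LM pattern: register a named seed once, then temporarily fork into that state around any op whose randomness must be reproducible per model-parallel rank. A rough sketch using the tracker instance visible later in this diff (seed value and dropout call are illustrative only):

# Register a named cuda RNG state once, e.g. during model-parallel seed setup.
_CUDA_RNG_STATE_TRACKER.add('model-parallel-rng', 2718)

# Ops inside the fork draw from the tracked stream; the default state is restored on exit.
with _CUDA_RNG_STATE_TRACKER.fork('model-parallel-rng'):
    y = torch.nn.functional.dropout(x, p=0.1, training=True)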
...@@ -227,13 +226,9 @@ def model_parallel_cuda_manual_seed(seed):
logger.info(
'> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(dist.get_rank(), tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed, data_parallel_seed), )
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
get_accelerator().manual_seed(data_parallel_seed)
...@@ -282,9 +277,7 @@ def gather_partitioned_activations(tensors, device=None):
if device is not None:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
else:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
partitions = []
for i in range(mp_size):
part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
...@@ -384,28 +377,21 @@ def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
i = arg_index - num_non_fp_tensors
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()
buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device
if contiguous_checkpoint:
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
for _ in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
for _ in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
...@@ -419,14 +405,10 @@ def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0
contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
...@@ -459,21 +441,14 @@ def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
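Taken together, partition_activations and gather_partitioned_activations implement the round trip sketched below: each model-parallel rank stores only its slice of a flattened activation, and the full tensor is rebuilt by an all-gather before recomputation. A simplified, self-contained sketch (not the helpers themselves; assumes the element count divides evenly by mp_size):

import torch
import torch.distributed as dist

def partition_for_rank(activation: torch.Tensor, rank: int, mp_size: int) -> torch.Tensor:
    # keep only this rank's 1/mp_size slice of the flattened activation
    flat = activation.detach().contiguous().view(-1)
    part_size = flat.numel() // mp_size
    return flat.narrow(0, rank * part_size, part_size).clone()

def gather_full(partition: torch.Tensor, original_shape, mp_size: int, group=None) -> torch.Tensor:
    # inverse step: every rank contributes its slice, then reshape to the saved size
    parts = [torch.empty_like(partition) for _ in range(mp_size)]
    dist.all_gather(parts, partition, group=group)
    return torch.cat(parts).view(original_shape)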
...@@ -499,13 +474,14 @@ def get_cpu_activations_for_backward(args, inputs):
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performance activation partitioning, contiguous memory optimization
4) CPU Checkpointing
5) Profile forward and backward functions
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
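Points 1 and 2 of the docstring amount to the RNG bookkeeping sketched below: forward snapshots the CPU, device, and tracker RNG states, and backward pushes them back before re-running run_function so recomputed dropout masks match. A rough, abridged sketch with helper names assumed from this module (not the actual diff hunks):

# In forward(): snapshot all three RNG states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

# In backward(): restore them before recomputing the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)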
...@@ -551,12 +527,9 @@ class CheckpointFunction(torch.autograd.Function):
see_memory_usage("First Forward Beginning", force=False)
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
logger.info(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
logger.info(f"----Synchronization {SYNCHRONIZE}")
logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
...@@ -564,18 +537,12 @@ class CheckpointFunction(torch.autograd.Function):
transport_stream = get_accelerator().Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
elif CPU_CHECKPOINT:
inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)
# just in case something funky is happening such as reuse of inputs
inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
...@@ -591,10 +558,7 @@ class CheckpointFunction(torch.autograd.Function):
del inputs_cuda
if PARTITION_ACTIVATIONS:
new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
save_args_for_backward(*new_args)
elif CPU_CHECKPOINT:
...@@ -613,9 +577,7 @@ class CheckpointFunction(torch.autograd.Function):
if torch.is_tensor(outputs):
non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
else:
non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
ctx.mark_non_differentiable(*non_grad_outputs)
if torch.is_tensor(outputs):
...@@ -661,14 +623,11 @@ class CheckpointFunction(torch.autograd.Function):
if PARTITION_ACTIVATIONS:
# with get_accelerator().stream(transport_stream):
inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
device=cuda_device if CPU_CHECKPOINT else None)
detached_inputs = detach_variable(inputs)
elif CPU_CHECKPOINT:
inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.deepspeed_saved_tensors
...@@ -762,8 +721,7 @@ def partition_activations_in_checkpoint(partition_activation):
global PARTITION_ACTIVATIONS
PARTITION_ACTIVATIONS = partition_activation
if dist.get_rank() == 0:
logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def set_num_layers(nlayers):
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
...@@ -48,16 +47,15 @@ ACT_CHKPT = 'activation_checkpointing'
ACT_CHKPT_DEFAULT = {
ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
}
class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()
...@@ -76,29 +74,21 @@ class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
self._initialize(act_chkpt_config_dict)
def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PARTITION_ACTIVATIONS,
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)
self.contiguous_memory_optimization = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)
self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_CPU_CHECKPOINTING,
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)
self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_NUMBER_CHECKPOINTS,
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)
self.profile = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PROFILE, ACT_CHKPT_PROFILE_DEFAULT)
self.synchronize_checkpoint_boundary = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)
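For reference, the keys this class reads correspond to the `activation_checkpointing` section of a DeepSpeed config. A minimal sketch of such a section (values illustrative; key names follow the ACT_CHKPT_* constants above):

ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
        "number_checkpoints": 4,
        "synchronize_checkpoint_boundary": False,
        "profile": False
    }
}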
""" # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
from collections import OrderedDict from collections import OrderedDict
import torch import torch
...@@ -13,29 +14,21 @@ from deepspeed.runtime import ZeROOptimizer ...@@ -13,29 +14,21 @@ from deepspeed.runtime import ZeROOptimizer
from packaging import version as pkg_version from packaging import version as pkg_version
from deepspeed.git_version_info import version from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
clip_tensors_by_global_norm, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
DummyOptim, is_model_parallel_parameter, see_memory_usage)
align_dense_tensors,
all_gather_dp_groups,
bwc_tensor_model_parallel_rank,
is_model_parallel_parameter,
see_memory_usage)
from deepspeed.utils import link_hp_params, fragment_address from deepspeed.utils import link_hp_params, fragment_address
from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
PARTITION_COUNT, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
CLIP_GRAD,
GROUP_PADDINGS,
PARAM_SLICE_MAPPINGS) PARAM_SLICE_MAPPINGS)
setattr(sys.modules[__name__], 'fragment_address', fragment_address) setattr(sys.modules[__name__], 'fragment_address', fragment_address)
class BF16_Optimizer(ZeROOptimizer): class BF16_Optimizer(ZeROOptimizer):
def __init__(self, def __init__(self,
init_optimizer, init_optimizer,
param_names, param_names,
...@@ -58,9 +51,7 @@ class BF16_Optimizer(ZeROOptimizer): ...@@ -58,9 +51,7 @@ class BF16_Optimizer(ZeROOptimizer):
self.allgather_bucket_size = int(allgather_bucket_size) self.allgather_bucket_size = int(allgather_bucket_size)
self.dp_process_group = dp_process_group self.dp_process_group = dp_process_group
self.dp_rank = dist.get_rank(group=self.dp_process_group) self.dp_rank = dist.get_rank(group=self.dp_process_group)
self.real_dp_process_group = [ self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
dp_process_group for i in range(len(self.optimizer.param_groups))
]
# Load pre-built or JIT compile (un)flatten ops # Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load() util_ops = UtilsBuilder().load()
...@@ -95,9 +86,7 @@ class BF16_Optimizer(ZeROOptimizer): ...@@ -95,9 +86,7 @@ class BF16_Optimizer(ZeROOptimizer):
def _setup_for_real_optimizer(self): def _setup_for_real_optimizer(self):
dp_world_size = dist.get_world_size(group=self.dp_process_group) dp_world_size = dist.get_world_size(group=self.dp_process_group)
self.partition_count = [ self.partition_count = [dp_world_size for i in range(len(self.optimizer.param_groups))]
dp_world_size for i in range(len(self.optimizer.param_groups))
]
for i, param_group in enumerate(self.optimizer.param_groups): for i, param_group in enumerate(self.optimizer.param_groups):
see_memory_usage(f'before initializing group {i}', force=True) see_memory_usage(f'before initializing group {i}', force=True)
...@@ -105,69 +94,55 @@ class BF16_Optimizer(ZeROOptimizer): ...@@ -105,69 +94,55 @@ class BF16_Optimizer(ZeROOptimizer):
partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_id = dist.get_rank(group=self.real_dp_process_group[i])
# grab the original list # grab the original list
self.bf16_groups.append(param_group['params']) trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
self.bf16_groups.append(trainable_parameters)
# create flat bf16 params # create flat bf16 params
self.bf16_groups_flat.append( self.bf16_groups_flat.append(
self._flatten_dense_tensors_aligned( self._flatten_dense_tensors_aligned(self.bf16_groups[i],
self.bf16_groups[i], self.nccl_start_alignment_factor * dp_world_size))
self.nccl_start_alignment_factor * dp_world_size))
# Make bf16 params point to flat tensor storage # Make bf16 params point to flat tensor storage
self._update_storage_to_flattened_tensor( self._update_storage_to_flattened_tensor(tensor_list=self.bf16_groups[i],
tensor_list=self.bf16_groups[i], flat_tensor=self.bf16_groups_flat[i])
flat_tensor=self.bf16_groups_flat[i])
# divide flat weights into equal sized partitions # divide flat weights into equal sized partitions
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
bf16_dp_partitions = [ bf16_dp_partitions = [
self.bf16_groups_flat[i].narrow(0, self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
dp_index * partition_size,
partition_size)
for dp_index in range(dp_world_size) for dp_index in range(dp_world_size)
] ]
self.bf16_partitioned_groups.append(bf16_dp_partitions) self.bf16_partitioned_groups.append(bf16_dp_partitions)
# create fp32 params partition # create fp32 params partition
self.fp32_groups_flat_partition.append( self.fp32_groups_flat_partition.append(bf16_dp_partitions[partition_id].clone().float().detach())
bf16_dp_partitions[partition_id].clone().float().detach())
self.fp32_groups_flat_partition[i].requires_grad = True self.fp32_groups_flat_partition[i].requires_grad = True
num_elem_list = [t.numel() for t in self.bf16_groups[i]] num_elem_list = [t.numel() for t in self.bf16_groups[i]]
# create fp32 gradients # create fp32 gradients
self.fp32_groups_gradients_flat.append( self.fp32_groups_gradients_flat.append(torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32))
torch.zeros_like(self.bf16_groups_flat[i],
dtype=torch.float32))
# track individual fp32 gradients for entire model # track individual fp32 gradients for entire model
fp32_gradients = self._split_flat_tensor( fp32_gradients = self._split_flat_tensor(flat_tensor=self.fp32_groups_gradients_flat[i],
flat_tensor=self.fp32_groups_gradients_flat[i], num_elem_list=num_elem_list)
num_elem_list=num_elem_list)
self.fp32_groups_gradients.append(fp32_gradients) self.fp32_groups_gradients.append(fp32_gradients)
self.fp32_groups_gradient_dict[i] = fp32_gradients self.fp32_groups_gradient_dict[i] = fp32_gradients
# flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding) # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
length_without_padding = sum(num_elem_list) length_without_padding = sum(num_elem_list)
self.fp32_groups_actual_gradients_flat.append( self.fp32_groups_actual_gradients_flat.append(
torch.narrow(self.fp32_groups_gradients_flat[i], torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))
0,
0,
length_without_padding))
# flat tensor corresponding to gradient partition # flat tensor corresponding to gradient partition
self.fp32_groups_gradient_flat_partition.append( self.fp32_groups_gradient_flat_partition.append(
torch.narrow(self.fp32_groups_gradients_flat[i], torch.narrow(self.fp32_groups_gradients_flat[i], 0, partition_id * partition_size, partition_size))
0,
partition_id * partition_size,
partition_size))
# track fp32 gradient updates # track fp32 gradient updates
self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i])) self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))
# Record padding required for alignment # Record padding required for alignment
if partition_id == dist.get_world_size( if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
group=self.real_dp_process_group[i]) - 1:
padding = self.bf16_groups_flat[i].numel() - length_without_padding padding = self.bf16_groups_flat[i].numel() - length_without_padding
else: else:
padding = 0 padding = 0
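The loop above flattens each bf16 param group, carves the flat buffer into dp_world_size equal narrow() views, and keeps an fp32 master copy of only this rank's slice. A toy, self-contained sketch of that carving (sizes and rank are made up for illustration):

import torch

dp_world_size, partition_id = 4, 1                     # illustrative values
flat = torch.zeros(16, dtype=torch.bfloat16)           # stands in for bf16_groups_flat[i]
partition_size = flat.numel() // dp_world_size

# one view per data-parallel rank; views share storage with the flat buffer
dp_partitions = [flat.narrow(0, r * partition_size, partition_size) for r in range(dp_world_size)]

# fp32 master weights exist only for this rank's slice
fp32_partition = dp_partitions[partition_id].clone().float().detach()
fp32_partition.requires_grad = True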
...@@ -199,8 +174,7 @@ class BF16_Optimizer(ZeROOptimizer):
for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
lp_name = self.param_names[lp]
param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
param_mapping.append(param_mapping_per_group)
return param_mapping
...@@ -212,17 +186,16 @@ class BF16_Optimizer(ZeROOptimizer):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
flat_hp_partition = self.fp32_groups_flat_partition[i]
link_hp_params(lp_param_list=self.bf16_groups[i],
flat_hp_partition=flat_hp_partition,
gradient_dict=self.fp32_groups_gradient_dict,
offload_gradient_dict=None,
use_offload=False,
param_group_index=i,
partition_start=partition_id * partition_size,
partition_size=partition_size,
partition_optimizer_state=self.optimizer.state[flat_hp_partition],
dp_group=self.real_dp_process_group[i])
def initialize_optimizer_states(self):
"""Take an optimizer step with zero-valued gradients to allocate internal
...@@ -231,7 +204,8 @@ class BF16_Optimizer(ZeROOptimizer):
This helps prevent memory fragmentation by allocating optimizer state at the
beginning of training instead of after activations have been allocated.
"""
for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
self.fp32_groups_gradient_flat_partition):
param_partition.grad = grad_partition
self.optimizer.step()
...@@ -262,19 +236,17 @@ class BF16_Optimizer(ZeROOptimizer):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')
all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(),
mpu=self.mpu,
norm_type=self.norm_type)
self._global_grad_norm = all_groups_norm
assert all_groups_norm > 0.
if self.clip_grad > 0.:
clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True),
max_norm=self.clip_grad,
global_norm=all_groups_norm,
mpu=self.mpu)
self.optimizer.step()
...@@ -343,7 +315,8 @@ class BF16_Optimizer(ZeROOptimizer):
@torch.no_grad()
def update_lp_params(self):
for i, (bf16_partitions,
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
...@@ -395,18 +368,11 @@ class BF16_Optimizer(ZeROOptimizer):
load_optimizer_states=True,
load_from_fp32_weights=False):
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
dp_rank = dist.get_rank(group=self.dp_process_group)
current_rank_sd = state_dict_list[dp_rank]
...@@ -421,17 +387,15 @@ class BF16_Optimizer(ZeROOptimizer):
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
if load_from_fp32_weights:
for current, saved in zip(self.fp32_groups_flat_partition,
current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
src_tensor = _get_padded_tensor(saved, current.numel())
current.data.copy_(src_tensor.data)
if load_optimizer_states:
self._link_all_hp_params()
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder)
@property
...@@ -448,11 +412,8 @@ class BF16_Optimizer(ZeROOptimizer):
for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
def _get_padded_tensor(src_tensor, size):
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
class CheckpointEngine(object):
...@@ -11,6 +16,9 @@ class CheckpointEngine(object):
# create checkpoint on give tag for save/load.
pass
def makedirs(self, path, exist_ok=False):
os.makedirs(path, exist_ok=exist_ok)
def save(self, state_dict, path: str):
pass
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
...@@ -15,6 +18,7 @@ def _get_tag_from_path(path):
class NebulaCheckpointEngine(CheckpointEngine):
def __init__(self, config_params=None):
super().__init__(config_params)
self.checkpoint = None
...@@ -27,8 +31,7 @@ class NebulaCheckpointEngine(CheckpointEngine):
nebula_config_params = {
NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention,
}
torch_nebula.init(**nebula_config_params)
...@@ -54,16 +57,13 @@ class NebulaCheckpointEngine(CheckpointEngine):
first_load_flag = self.tag_flag is None or self.tag_flag == tag
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location)
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition
partititon_name = os.path.basename(path)
logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...")
checkpoint = None
if tag in (None, 'latest', 'latest_universal'):
...@@ -72,44 +72,36 @@ class NebulaCheckpointEngine(CheckpointEngine):
# will try to load the valid latest checkpoint from nebula(tier3 > tier1). So, in summary
# when met failure loading for given tag, the loading priority would be like:
# nebula tier3 latest > nebula tier1 latest.
checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
else:
checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path)
if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
logger.info(
f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!"
)
# nebula tier3 latest
checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
logger.info(
f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!"
)
# nebula tier1 latest
checkpoint = torch_nebula.get_latest_checkpoint()
logger.warning(f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
return None
tag = checkpoint.tag
self.tag_flag = -1
partition = checkpoint.load(partititon_name, map_location=map_location)
logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
return partition
def commit(self, tag):
# nebula commit will be call when all files under give tag are ready to be persisted in the async way.
logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
commit_rls = self.checkpoint.commit()
if not commit_rls:
logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.")
return False
return commit_rls
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.utils import logger, log_dist
...@@ -7,11 +10,12 @@ from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
class TorchCheckpointEngine(CheckpointEngine):
def __init__(self, config_params=None):
super().__init__(config_params)
def create(self, tag):
log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0])
def save(self, state_dict, path: str):
logger.info(f"[Torch] Saving {path}...")
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
batched collective operations for overhead amortization and better
bandwidth utilization
"""
import math
from typing import List
...@@ -15,15 +20,8 @@ import torch.nn.functional
from deepspeed.utils import instrument_w_nvtx
def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group=None, async_op=False, prof=False):
return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, group=group, async_op=async_op)
@instrument_w_nvtx
...@@ -45,13 +43,10 @@ def reduce_scatter_coalesced(
flattened_tensor = tensor.view(-1)
chunk_sz = math.ceil(tensor.numel() / world_sz)
partition_lst_for_each_tensor[tensor_idx] = [
flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz] for rank in range(0, world_sz)
]
padded_partition_sz_for_each_tensor = tuple(math.ceil(t.numel() / world_sz) for t in tensors)
if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
# if there's only one tensor being reduced and we don't need to pad
...@@ -68,21 +63,15 @@ def reduce_scatter_coalesced(
tensor_partitions_lst_with_padding.append(tensor_chunk)
# add padding if necessary
padding_sz = padded_partition_sz_for_each_tensor[tensor_idx] - tensor_chunk.numel()
if padding_sz > 0:
tensor_partitions_lst_with_padding.append(
torch.empty(padding_sz, dtype=tensor_chunk.dtype, device=tensor_chunk.device))
tensor_partition_flat_buffer = instrument_w_nvtx(torch.cat)(tensor_partitions_lst_with_padding)
tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(tensor_partition_flat_buffer, world_sz)
# batched reduce-scatter call
_torch_reduce_scatter_fn(tensor_partition_flat_buffer,
...@@ -95,9 +84,7 @@ def reduce_scatter_coalesced(
offset = 0
for tensor_idx in range(len(tensors)):
output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
0, offset, partition_lst_for_each_tensor[tensor_idx][this_rank].numel())
offset += padded_partition_sz_for_each_tensor[tensor_idx]
......
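To make the interleaving in reduce_scatter_coalesced concrete, here is a small sketch of the chunk-size and padding arithmetic applied before the batched reduce-scatter (pure bookkeeping, no communication; names are illustrative):

import math

def padded_chunk_layout(tensor_numels, world_sz):
    # each tensor contributes ceil(numel / world_sz) elements per rank; short chunks are padded
    padded_sz = [math.ceil(n / world_sz) for n in tensor_numels]
    return padded_sz, sum(padded_sz)

# e.g. tensors of 10 and 7 elements over 4 ranks -> chunks of 3 and 2, 5 elements per rank
print(padded_chunk_layout([10, 7], 4))   # ([3, 2], 5)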
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import cupy
...@@ -12,6 +13,7 @@ from deepspeed.runtime.compression.cupy import CupyBackend
class MpiBackend(object):
def __init__(self, cuda_aware):
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
...@@ -31,49 +33,26 @@ class MpiBackend(object):
req.append(comm.Isend(sendbuf, dest=root))
return req
def gather_cuda(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = self.my_igather(rank, world_size, comm, cupy_sign_list_packed[idx], cupy_recvbuf_sign, root=idx)
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank, world_size, comm, cupy_worker_scale, cupy_recvbuf_scale, root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
def gather_host(self, rank, world_size, comm, cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size, cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)
...@@ -101,12 +80,7 @@ class MpiBackend(object):
requests += req_sign
for idx in range(world_size):
req_scale = self.my_igather(rank, world_size, comm, numpy_worker_scale, numpy_recvbuf_scale, root=idx)
requests += req_scale
MPI.Request.Waitall(requests)
...@@ -122,30 +96,18 @@ class MpiBackend(object):
return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale
def allgather_cuda(self, def allgather_cuda(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server): cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server) comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server) comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
def allgather_host(self, def allgather_host(self, comm, cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale,
comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server): cupy_recvbuf_scale_server):
# 1. Convert cupy to numpy # 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros( numpy_recvbuf_sign_server = np.zeros([comm.Get_size(), cupy_server_sign_packed.size],
[comm.Get_size(), dtype=cupy_server_sign_packed.dtype)
cupy_server_sign_packed.size], numpy_recvbuf_scale_server = np.zeros([comm.Get_size(), 1], dtype=cupy_server_scale.dtype)
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)
numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed) numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server) numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
...@@ -167,11 +129,7 @@ class MpiBackend(object): ...@@ -167,11 +129,7 @@ class MpiBackend(object):
return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
def compressed_allreduce(self, def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
buffer_m: torch.tensor,
worker_error,
server_error,
local_rank):
all_start_time = time.time() all_start_time = time.time()
original_shape = buffer_m.size() original_shape = buffer_m.size()
...@@ -182,104 +140,71 @@ class MpiBackend(object): ...@@ -182,104 +140,71 @@ class MpiBackend(object):
cupy.cuda.Device(local_rank).use() cupy.cuda.Device(local_rank).use()
if original_size != worker_error_size: if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size, empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor]) buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error) buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m)) worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
worker_error.set_(buffer_m - worker_scale * worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
cupy_sign_list_packed = self.compression_backend.compress_by_chunk( cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size)
self.size)
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale) cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_recvbuf_sign = cupy.zeros( cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
[self.size, dtype=cupy_sign_list_packed[0].dtype)
cupy_sign_list_packed[self.rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype) cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
# Communication Phase 1 # Communication Phase 1
gather_start = time.time() gather_start = time.time()
if self.cuda_aware: if self.cuda_aware:
self.gather_cuda(self.rank, self.gather_cuda(self.rank, self.size, self.comm, cupy_sign_list_packed, cupy_recvbuf_sign,
self.size, cupy_worker_scale, cupy_recvbuf_scale)
self.comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else: else:
_, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(self.rank, _, cupy_recvbuf_sign, _, cupy_recvbuf_scale = self.gather_host(self.rank, self.size, self.comm,
self.size, cupy_sign_list_packed, cupy_recvbuf_sign,
self.comm, cupy_worker_scale, cupy_recvbuf_scale)
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time() gather_end = time.time()
# cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None # cupy_sign_list_packed, cupy_worker_scale, worker_scale = None, None, None
cupy_sign_list_packed = None cupy_sign_list_packed = None
compensated_server_m = self.compression_backend.cupy2torch( compensated_server_m = self.compression_backend.cupy2torch(
(cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape( (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
self.size, self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(1 / self.size)).sum(0)
-1)).float().add_(-0.5).mul_(2.0).mul_(
self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
1 / self.size)).sum(0)
compensated_server_m.add_(server_error) compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt( server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
compensated_server_m.numel()) server_error.set_(compensated_server_m -
server_error.set_( server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
compensated_server_m - server_scale *
compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
cupy_server_scale = self.compression_backend.torch2cupy(server_scale) cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
    self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1)
compensated_server_m = None compensated_server_m = None
cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                      dtype=cupy_recvbuf_sign.dtype)
cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_recvbuf_scale.dtype)
# cupy_recvbuf_sign, cupy_recvbuf_scale = None, None # cupy_recvbuf_sign, cupy_recvbuf_scale = None, None
cupy_recvbuf_sign = None cupy_recvbuf_sign = None
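# Editor's note: sketch (illustrative names, not the backend API) of the server-side
# step above: unpack each worker's packed sign bits, map {0,1} -> {-1,+1}, weight by
# that worker's scale, average over workers, then re-quantize with server error feedback.
import numpy as np
import torch

def server_aggregate(packed_signs: np.ndarray, scales: torch.Tensor, server_error: torch.Tensor):
    # packed_signs: uint8 [world_size, n_bytes]; scales: [world_size, 1]; server_error: [n_bits]
    unpacked = torch.from_numpy(np.unpackbits(packed_signs, axis=1)).float()
    per_worker = unpacked.add_(-0.5).mul_(2.0) * (scales / packed_signs.shape[0])
    compensated = per_worker.sum(0) + server_error
    server_scale = torch.norm(compensated) / np.sqrt(compensated.numel())
    server_signs = compensated.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    new_server_error = compensated - server_scale * server_signs
    return server_signs, server_scale, new_server_error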
# Communication Phase 2 # Communication Phase 2
if self.cuda_aware: if self.cuda_aware:
    self.allgather_cuda(self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale,
                        cupy_recvbuf_scale_server)
else: else:
    _, cupy_recvbuf_sign_server, _, cupy_recvbuf_scale_server = self.allgather_host(
        self.comm, cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale,
        cupy_recvbuf_scale_server)
# cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None # cupy_server_sign_packed, cupy_server_scale, server_scale = None, None, None
cupy_server_sign_packed = None cupy_server_sign_packed = None
buffer_m.data.copy_(
    self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
        self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
            self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)
if original_size != worker_error_size: if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size] buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1: if len(original_shape) > 1:
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch import torch
from deepspeed import comm as dist from deepspeed import comm as dist
...@@ -12,6 +13,7 @@ from deepspeed.accelerator import get_accelerator ...@@ -12,6 +13,7 @@ from deepspeed.accelerator import get_accelerator
class NcclBackend(object): class NcclBackend(object):
def __init__(self, mpu=None): def __init__(self, mpu=None):
if mpu is None: if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size())) self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
...@@ -24,7 +26,7 @@ class NcclBackend(object): ...@@ -24,7 +26,7 @@ class NcclBackend(object):
self.bool_not_supported = False self.bool_not_supported = False
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR >= 1 and TORCH_MINOR >= 10: if (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) or TORCH_MAJOR == 2:
self.bool_not_supported = True self.bool_not_supported = True
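# Editor's note: illustrative check of the version parsing used above. Splitting
# torch.__version__ on '.' gives e.g. ("1", "13") or ("2", "0"); the corrected
# condition also matches torch 2.x, where the minor number resets to 0.
def needs_uint8_workaround(version: str) -> bool:
    major, minor = (int(x) for x in version.split('.')[:2])
    return (major == 1 and minor >= 10) or major == 2

assert needs_uint8_workaround("1.10.0")
assert needs_uint8_workaround("2.0.1")
assert not needs_uint8_workaround("1.9.1")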
def my_igather(self, rank, size, group, sendbuf, recvbuf, root): def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
...@@ -49,11 +51,7 @@ class NcclBackend(object): ...@@ -49,11 +51,7 @@ class NcclBackend(object):
else: else:
dist.send(sendbuf, group=group, dst=root) dist.send(sendbuf, group=group, dst=root)
def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error, local_rank):
# all_start_time = time.time() # all_start_time = time.time()
original_shape = buffer_m.size() original_shape = buffer_m.size()
...@@ -64,53 +62,41 @@ class NcclBackend(object): ...@@ -64,53 +62,41 @@ class NcclBackend(object):
cupy.cuda.Device(local_rank).use() cupy.cuda.Device(local_rank).use()
if original_size != worker_error_size: if original_size != worker_error_size:
empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor]) buffer_m = torch.cat([buffer_m, empty_tensor])
buffer_m.add_(worker_error) buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(buffer_m.numel()) worker_scale = torch.norm(buffer_m) / np.sqrt(buffer_m.numel())
worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
if self.bool_not_supported: if self.bool_not_supported:
    cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool().to(dtype=torch.uint8)), self.size)
else:
    cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()), self.size)
cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale) cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
cupy_recvbuf_sign = cupy.zeros([self.size, cupy_sign_list_packed[self.rank].size],
                               dtype=cupy_sign_list_packed[0].dtype)
# cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype) # cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
sign_list_packed = [
    self.compression_backend.cupy2torch(cupy_sign_list_packed[idx]) for idx in range(self.size)
]
# worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale) # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign) recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
#recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale) #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
recvbuf_scale = [
    torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(get_accelerator().device_name(local_rank)))
    for i in range(self.size)
]
# communication phase 1 # communication phase 1
# gather_start = time.time() # gather_start = time.time()
# Alltoall for sign # Alltoall for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
# Allgather for scale # Allgather for scale
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)
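# Editor's note: a single-process simulation (illustrative only) of the phase-1 exchange
# above. all_to_all_single sends chunk r of every worker's packed signs to rank r, so each
# rank ends up holding "its" chunk from all workers; all_gather replicates the per-worker
# scales everywhere.
def simulate_phase1(sign_chunks, scales):
    # sign_chunks[w][r] is worker w's packed chunk destined for rank r
    world_size = len(sign_chunks)
    recv_sign = [[sign_chunks[w][r] for w in range(world_size)] for r in range(world_size)]
    recv_scale = [list(scales) for _ in range(world_size)]
    return recv_sign, recv_scale

chunks = [[f"w{w}c{r}" for r in range(3)] for w in range(3)]
recv_sign, recv_scale = simulate_phase1(chunks, [0.5, 0.7, 0.9])
assert recv_sign[1] == ["w0c1", "w1c1", "w2c1"]  # rank 1 holds chunk 1 of every worker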
...@@ -123,61 +109,44 @@ class NcclBackend(object): ...@@ -123,61 +109,44 @@ class NcclBackend(object):
#cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale)) #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))
compensated_server_m = self.compression_backend.cupy2torch(
    (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
        torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
compensated_server_m.add_(server_error) compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_error.set_(compensated_server_m -
                  server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
# cupy_server_scale = self.compression_backend.torch2cupy(server_scale) # cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
if self.bool_not_supported: if self.bool_not_supported:
    cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool().to(dtype=torch.uint8)),
        1)
else:
    cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(compensated_server_m.sign_().add_(1).bool()), 1)
compensated_server_m = None compensated_server_m = None
cupy_recvbuf_sign_server = cupy.zeros([self.size, cupy_server_sign_packed[0].size],
                                      dtype=cupy_recvbuf_sign.dtype)
# cupy_recvbuf_sign, recvbuf_sign = None, None # cupy_recvbuf_sign, recvbuf_sign = None, None
cupy_recvbuf_sign = None cupy_recvbuf_sign = None
server_sign_packed = [self.compression_backend.cupy2torch(cupy_server_sign_packed[0])]
recvbuf_sign_server = [
    self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx]) for idx in range(self.size)
]
# server_scale = self.compression_backend.cupy2torch(cupy_server_scale) # server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
cupy_recvbuf_scale_server = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)
# cupy_recvbuf_scale, recvbuf_scale = None, None # cupy_recvbuf_scale, recvbuf_scale = None, None
recvbuf_scale_server = [
    self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx]) for idx in range(self.size)
]
# Communication Phase 2 # Communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)
cupy_server_sign_packed = None cupy_server_sign_packed = None
...@@ -186,16 +155,12 @@ class NcclBackend(object): ...@@ -186,16 +155,12 @@ class NcclBackend(object):
# dist.all_gather only provides a tensor list as the recv/output buffer # dist.all_gather only provides a tensor list as the recv/output buffer
recvbuf_sign_server = torch.stack(recvbuf_sign_server) recvbuf_sign_server = torch.stack(recvbuf_sign_server)
cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(recvbuf_sign_server)
buffer_m.data.copy_(
    self.compression_backend.cupy2torch((cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
        self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
            self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)).flatten().data)
if original_size != worker_error_size: if original_size != worker_error_size:
buffer_m = buffer_m[0:original_size] buffer_m = buffer_m[0:original_size]
if len(original_shape) > 1: if len(original_shape) > 1:
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import cupy import cupy
from torch.utils.dlpack import to_dlpack from torch.utils.dlpack import to_dlpack
...@@ -8,6 +9,7 @@ from torch.utils.dlpack import from_dlpack ...@@ -8,6 +9,7 @@ from torch.utils.dlpack import from_dlpack
class CupyBackend(object): class CupyBackend(object):
def __init__(self): def __init__(self):
pass pass
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import os import os
from typing import Union from typing import Union
from enum import Enum
import torch import torch
import json import json
...@@ -30,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig ...@@ -30,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config from ..monitor.config import get_monitor_config
from deepspeed import comm as dist from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from ..git_version_info import version as __version__ from ..git_version_info import version as __version__
from ..utils import logger from ..utils import logger
...@@ -71,12 +73,7 @@ ONEBIT_ADAM_OPTIMIZER = 'onebitadam' ...@@ -71,12 +73,7 @@ ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam' ZERO_ONE_ADAM_OPTIMIZER = 'zerooneadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [ DEEPSPEED_OPTIMIZERS = [
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER,
ZERO_ONE_ADAM_OPTIMIZER ZERO_ONE_ADAM_OPTIMIZER
] ]
...@@ -92,11 +89,36 @@ class DeepSpeedConfigError(Exception): ...@@ -92,11 +89,36 @@ class DeepSpeedConfigError(Exception):
pass pass
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj
def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)
def get_pld_enabled(param_dict): def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys(): if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], PLD_ENABLED, PLD_ENABLED_DEFAULT)
else: else:
return False return False
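# Editor's note: the config getters in this file all share one pattern. get_scalar_param
# (imported from the config utilities) is assumed here to behave like dict.get with a
# default; a minimal stand-in for illustration (hypothetical helper name):
def _get_scalar_param_sketch(param_dict, key, default):
    return param_dict.get(key, default)

assert _get_scalar_param_sketch({"enabled": True}, "enabled", False) is True
assert _get_scalar_param_sketch({}, "enabled", False) is False  # missing key -> default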
...@@ -136,17 +158,13 @@ def get_fp16_enabled(param_dict): ...@@ -136,17 +158,13 @@ def get_fp16_enabled(param_dict):
def get_bfloat16_enabled(param_dict): def get_bfloat16_enabled(param_dict):
for key in [BFLOAT16, BFLOAT16_OLD]: for key in [BFLOAT16, BFLOAT16_OLD]:
if key in param_dict.keys(): if key in param_dict.keys():
return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT)
return False return False
def get_fp16_master_weights_and_grads_enabled(param_dict): def get_fp16_master_weights_and_grads_enabled(param_dict):
if get_fp16_enabled(param_dict): if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
else: else:
return False return False
...@@ -158,9 +176,7 @@ def get_fp16_auto_cast(param_dict): ...@@ -158,9 +176,7 @@ def get_fp16_auto_cast(param_dict):
def get_loss_scale(param_dict): def get_loss_scale(param_dict):
if get_fp16_enabled(param_dict): if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT)
elif get_bfloat16_enabled(param_dict): elif get_bfloat16_enabled(param_dict):
return 1.0 return 1.0
else: else:
...@@ -169,8 +185,7 @@ def get_loss_scale(param_dict): ...@@ -169,8 +185,7 @@ def get_loss_scale(param_dict):
def get_initial_dynamic_scale(param_dict): def get_initial_dynamic_scale(param_dict):
if get_fp16_enabled(param_dict): if get_fp16_enabled(param_dict):
initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER,
                                       FP16_INITIAL_SCALE_POWER_DEFAULT)
elif get_bfloat16_enabled(param_dict): elif get_bfloat16_enabled(param_dict):
initial_scale_power = 0 initial_scale_power = 0
...@@ -191,18 +206,10 @@ def get_dynamic_loss_scale_args(param_dict): ...@@ -191,18 +206,10 @@ def get_dynamic_loss_scale_args(param_dict):
FP16_HYSTERESIS, FP16_HYSTERESIS,
] ]
if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args):
init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT)
scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT)
delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT)
min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT)
loss_scale_args = { loss_scale_args = {
INITIAL_LOSS_SCALE: 2**init_scale, INITIAL_LOSS_SCALE: 2**init_scale,
SCALE_WINDOW: scale_window, SCALE_WINDOW: scale_window,
...@@ -214,9 +221,7 @@ def get_dynamic_loss_scale_args(param_dict): ...@@ -214,9 +221,7 @@ def get_dynamic_loss_scale_args(param_dict):
def get_gradient_accumulation_steps(param_dict): def get_gradient_accumulation_steps(param_dict):
return get_scalar_param(param_dict, GRADIENT_ACCUMULATION_STEPS, GRADIENT_ACCUMULATION_STEPS_DEFAULT)
def get_sparse_gradients_enabled(param_dict): def get_sparse_gradients_enabled(param_dict):
...@@ -224,9 +229,7 @@ def get_sparse_gradients_enabled(param_dict): ...@@ -224,9 +229,7 @@ def get_sparse_gradients_enabled(param_dict):
def get_communication_data_type(param_dict): def get_communication_data_type(param_dict):
val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
val = val.lower() if val is not None else val val = val.lower() if val is not None else val
if val is None: if val is None:
return val # we must determine it by other parameters return val # we must determine it by other parameters
...@@ -237,9 +240,7 @@ def get_communication_data_type(param_dict): ...@@ -237,9 +240,7 @@ def get_communication_data_type(param_dict):
elif val == "bfp16": elif val == "bfp16":
return torch.bfloat16 return torch.bfloat16
raise ValueError(f"Invalid communication_data_type. Supported data types: ['fp16', 'bfp16', 'fp32']. Got: {val}")
def get_prescale_gradients(param_dict): def get_prescale_gradients(param_dict):
...@@ -247,9 +248,7 @@ def get_prescale_gradients(param_dict): ...@@ -247,9 +248,7 @@ def get_prescale_gradients(param_dict):
def get_gradient_predivide_factor(param_dict): def get_gradient_predivide_factor(param_dict):
return get_scalar_param(param_dict, GRADIENT_PREDIVIDE_FACTOR, GRADIENT_PREDIVIDE_FACTOR_DEFAULT)
def get_steps_per_print(param_dict): def get_steps_per_print(param_dict):
...@@ -284,8 +283,7 @@ def get_sparse_attention(param_dict): ...@@ -284,8 +283,7 @@ def get_sparse_attention(param_dict):
elif mode == SPARSE_BSLONGFORMER_MODE: elif mode == SPARSE_BSLONGFORMER_MODE:
return get_sparse_bslongformer_config(sparsity) return get_sparse_bslongformer_config(sparsity)
else: else:
raise NotImplementedError(f"Given sparsity mode, {mode}, has not been implemented yet!")
else: else:
return None return None
...@@ -303,15 +301,9 @@ def get_sparse_fixed_config(sparsity): ...@@ -303,15 +301,9 @@ def get_sparse_fixed_config(sparsity):
SPARSE_DIFFERENT_LAYOUT_PER_HEAD, SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT, SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
) )
num_local_blocks = get_scalar_param(sparsity, SPARSE_NUM_LOCAL_BLOCKS, SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param( horizontal_global_attention = get_scalar_param(
sparsity, sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION, SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
...@@ -342,23 +334,15 @@ def get_sparse_variable_config(sparsity): ...@@ -342,23 +334,15 @@ def get_sparse_variable_config(sparsity):
SPARSE_DIFFERENT_LAYOUT_PER_HEAD, SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT, SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
) )
num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
local_window_blocks = get_scalar_param(sparsity, SPARSE_LOCAL_WINDOW_BLOCKS, SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param( global_block_end_indices = get_scalar_param(
sparsity, sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES, SPARSE_GLOBAL_BLOCK_END_INDICES,
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT, SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT,
) )
attention = get_scalar_param(sparsity, SPARSE_ATTENTION_TYPE, SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param( horizontal_global_attention = get_scalar_param(
sparsity, sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION, SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
...@@ -385,17 +369,13 @@ def get_sparse_bigbird_config(sparsity): ...@@ -385,17 +369,13 @@ def get_sparse_bigbird_config(sparsity):
SPARSE_DIFFERENT_LAYOUT_PER_HEAD, SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT, SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT,
) )
num_random_blocks = get_scalar_param(sparsity, SPARSE_NUM_RANDOM_BLOCKS, SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
num_sliding_window_blocks = get_scalar_param( num_sliding_window_blocks = get_scalar_param(
sparsity, sparsity,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS, SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT, SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
) )
num_global_blocks = get_scalar_param(sparsity, SPARSE_NUM_GLOBAL_BLOCKS, SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
return { return {
SPARSE_MODE: SPARSE_BIGBIRD_MODE, SPARSE_MODE: SPARSE_BIGBIRD_MODE,
...@@ -419,9 +399,7 @@ def get_sparse_bslongformer_config(sparsity): ...@@ -419,9 +399,7 @@ def get_sparse_bslongformer_config(sparsity):
SPARSE_NUM_SLIDING_WINDOW_BLOCKS, SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT, SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT,
) )
global_block_indices = get_scalar_param(sparsity, SPARSE_GLOBAL_BLOCK_INDICES, SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param( global_block_end_indices = get_scalar_param(
sparsity, sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES, SPARSE_GLOBAL_BLOCK_END_INDICES,
...@@ -474,8 +452,7 @@ def get_optimizer_name(param_dict): ...@@ -474,8 +452,7 @@ def get_optimizer_name(param_dict):
def get_optimizer_params(param_dict): def get_optimizer_params(param_dict):
if (get_optimizer_name(param_dict) is not None and OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys()):
return param_dict[OPTIMIZER][OPTIMIZER_PARAMS] return param_dict[OPTIMIZER][OPTIMIZER_PARAMS]
else: else:
return None return None
...@@ -497,9 +474,11 @@ def get_optimizer_legacy_fusion(param_dict): ...@@ -497,9 +474,11 @@ def get_optimizer_legacy_fusion(param_dict):
def get_zero_allow_untested_optimizer(param_dict): def get_zero_allow_untested_optimizer(param_dict):
return get_scalar_param(param_dict, ZERO_ALLOW_UNTESTED_OPTIMIZER, ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)
def get_zero_force_ds_cpu_optimizer(param_dict):
return get_scalar_param(param_dict, ZERO_FORCE_DS_CPU_OPTIMIZER, ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)
def get_scheduler_name(param_dict): def get_scheduler_name(param_dict):
...@@ -510,8 +489,7 @@ def get_scheduler_name(param_dict): ...@@ -510,8 +489,7 @@ def get_scheduler_name(param_dict):
def get_scheduler_params(param_dict): def get_scheduler_params(param_dict):
if (get_scheduler_name(param_dict) is not None and SCHEDULER_PARAMS in param_dict[SCHEDULER].keys()):
return param_dict[SCHEDULER][SCHEDULER_PARAMS] return param_dict[SCHEDULER][SCHEDULER_PARAMS]
else: else:
return None return None
...@@ -530,15 +508,28 @@ def get_train_micro_batch_size_per_gpu(param_dict): ...@@ -530,15 +508,28 @@ def get_train_micro_batch_size_per_gpu(param_dict):
def get_wall_clock_breakdown(param_dict): def get_wall_clock_breakdown(param_dict):
return get_scalar_param(param_dict, WALL_CLOCK_BREAKDOWN, WALL_CLOCK_BREAKDOWN_DEFAULT)
def get_memory_breakdown(param_dict): def get_memory_breakdown(param_dict):
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
class HybridEngineConfig(DeepSpeedConfigModel):
enabled: bool = False
max_out_tokens: int = 512
inference_tp_size: int = 1
release_inference_cache: bool = False
pin_parameters: bool = True
tp_gather_partition_size: int = 8
def get_hybrid_engine_config(param_dict):
hybrid_engine_config_dict = param_dict.get("hybrid_engine", {})
hybrid_engine_config = HybridEngineConfig(**hybrid_engine_config_dict)
return hybrid_engine_config
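# Editor's note: illustrative use of the hybrid-engine parsing added above; any key
# omitted from the user config keeps the default declared on HybridEngineConfig
# (assumes the pydantic model behaves as defined above).
_example_cfg = {"hybrid_engine": {"enabled": True, "max_out_tokens": 1024}}
_he = get_hybrid_engine_config(_example_cfg)
assert _he.enabled and _he.max_out_tokens == 1024
assert _he.inference_tp_size == 1  # default preserved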
def get_eigenvalue_config(param_dict): def get_eigenvalue_config(param_dict):
if get_quantize_enabled(param_dict): if get_quantize_enabled(param_dict):
param_dict = param_dict[QUANTIZE_TRAINING] param_dict = param_dict[QUANTIZE_TRAINING]
...@@ -568,45 +559,35 @@ def get_eigenvalue_config(param_dict): ...@@ -568,45 +559,35 @@ def get_eigenvalue_config(param_dict):
def get_eigenvalue_enabled(param_dict): def get_eigenvalue_enabled(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_ENABLED, EIGENVALUE_ENABLED_DEFAULT)
else: else:
return EIGENVALUE_ENABLED_DEFAULT return EIGENVALUE_ENABLED_DEFAULT
def get_eigenvalue_verbose(param_dict): def get_eigenvalue_verbose(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_VERBOSE, EIGENVALUE_VERBOSE_DEFAULT)
EIGENVALUE_VERBOSE,
EIGENVALUE_VERBOSE_DEFAULT)
else: else:
return EIGENVALUE_VERBOSE_DEFAULT return EIGENVALUE_VERBOSE_DEFAULT
def get_eigenvalue_max_iter(param_dict): def get_eigenvalue_max_iter(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_MAX_ITER, EIGENVALUE_MAX_ITER_DEFAULT)
else: else:
return EIGENVALUE_MAX_ITER_DEFAULT return EIGENVALUE_MAX_ITER_DEFAULT
def get_eigenvalue_tol(param_dict): def get_eigenvalue_tol(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_TOL, EIGENVALUE_TOL_DEFAULT)
else: else:
return EIGENVALUE_TOL_DEFAULT return EIGENVALUE_TOL_DEFAULT
def get_eigenvalue_stability(param_dict): def get_eigenvalue_stability(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_STABILITY, EIGENVALUE_STABILITY_DEFAULT)
else: else:
return EIGENVALUE_STABILITY_DEFAULT return EIGENVALUE_STABILITY_DEFAULT
...@@ -624,18 +605,14 @@ def get_eigenvalue_gas_boundary_resolution(param_dict): ...@@ -624,18 +605,14 @@ def get_eigenvalue_gas_boundary_resolution(param_dict):
def get_eigenvalue_layer_name(param_dict): def get_eigenvalue_layer_name(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NAME, EIGENVALUE_LAYER_NAME_DEFAULT)
else: else:
return EIGENVALUE_LAYER_NAME_DEFAULT return EIGENVALUE_LAYER_NAME_DEFAULT
def get_eigenvalue_layer_num(param_dict): def get_eigenvalue_layer_num(param_dict):
if EIGENVALUE in param_dict.keys(): if EIGENVALUE in param_dict.keys():
return get_scalar_param(param_dict[EIGENVALUE], EIGENVALUE_LAYER_NUM, EIGENVALUE_LAYER_NUM_DEFAULT)
else: else:
return EIGENVALUE_LAYER_NUM_DEFAULT return EIGENVALUE_LAYER_NUM_DEFAULT
...@@ -649,35 +626,29 @@ def get_data_types_params(param_dict): ...@@ -649,35 +626,29 @@ def get_data_types_params(param_dict):
def get_checkpoint_tag_validation_mode(checkpoint_params): def get_checkpoint_tag_validation_mode(checkpoint_params):
tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT)
tag_validation_mode = tag_validation_mode.upper() tag_validation_mode = tag_validation_mode.upper()
if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES: if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES:
return tag_validation_mode return tag_validation_mode
else: else:
raise DeepSpeedConfigError(
    "Checkpoint config contains invalid tag_validation "
    f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")
def get_checkpoint_parallel_write_pipeline(checkpoint_params): def get_checkpoint_parallel_write_pipeline(checkpoint_params):
par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {}) par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
par_write_pipeline = par_write_params.get(CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
                                          CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
if par_write_pipeline in [True, False]: if par_write_pipeline in [True, False]:
return par_write_pipeline return par_write_pipeline
else: else:
raise DeepSpeedConfigError("checkpoint::parallel_write::pipeline_stage "
                           f"value of '{par_write_pipeline}' is invalid, expecting: true or false")
def get_dataloader_drop_last(param_dict): def get_dataloader_drop_last(param_dict):
return get_scalar_param(param_dict, DATALOADER_DROP_LAST, DATALOADER_DROP_LAST_DEFAULT)
'''Write deepspeed config files by modifying basic templates. '''Write deepspeed config files by modifying basic templates.
...@@ -685,6 +656,7 @@ Can be used for quickly changing parameters via command line parameters.''' ...@@ -685,6 +656,7 @@ Can be used for quickly changing parameters via command line parameters.'''
class DeepSpeedConfigWriter: class DeepSpeedConfigWriter:
def __init__(self, data=None): def __init__(self, data=None):
self.data = data if data is not None else {} self.data = data if data is not None else {}
...@@ -692,9 +664,7 @@ class DeepSpeedConfigWriter: ...@@ -692,9 +664,7 @@ class DeepSpeedConfigWriter:
self.data[key] = value self.data[key] = value
def load_config(self, filename): def load_config(self, filename):
self.data = json.load(open(filename, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
def write_config(self, filename): def write_config(self, filename):
with open(filename, "w") as outfile: with open(filename, "w") as outfile:
...@@ -702,15 +672,13 @@ class DeepSpeedConfigWriter: ...@@ -702,15 +672,13 @@ class DeepSpeedConfigWriter:
class DeepSpeedConfig(object): class DeepSpeedConfig(object):
def __init__(self, config: Union[str, dict], mpu=None): def __init__(self, config: Union[str, dict], mpu=None):
super(DeepSpeedConfig, self).__init__() super(DeepSpeedConfig, self).__init__()
if isinstance(config, dict): if isinstance(config, dict):
self._param_dict = config self._param_dict = config
elif os.path.exists(config): elif os.path.exists(config):
self._param_dict = hjson.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
else: else:
try: try:
config_decoded = base64.urlsafe_b64decode(config).decode('utf-8') config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
...@@ -744,24 +712,18 @@ class DeepSpeedConfig(object): ...@@ -744,24 +712,18 @@ class DeepSpeedConfig(object):
# Ensure the resource scheduler saw the same elastic config we are using at runtime # Ensure the resource scheduler saw the same elastic config we are using at runtime
ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)
self.elastic_model_parallel_size = elastic_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT)
if self.elastic_model_parallel_size < 1: if self.elastic_model_parallel_size < 1:
raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
                            f"given model-parallel size: {self.elastic_model_parallel_size}")
self.num_gpus_per_node = elastic_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1: if self.num_gpus_per_node < 1:
raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
                            f"given number of GPUs per node: {self.num_gpus_per_node}")
ignore_non_elastic_batch_info = elastic_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
                                                 IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
if not ignore_non_elastic_batch_info: if not ignore_non_elastic_batch_info:
batch_params = [ batch_params = [
...@@ -779,23 +741,17 @@ class DeepSpeedConfig(object): ...@@ -779,23 +741,17 @@ class DeepSpeedConfig(object):
# micro_bsz * world_size * gas = total_batch_size # micro_bsz * world_size * gas = total_batch_size
# gas = total_batch_size // (micro_bsz * world_size) # gas = total_batch_size // (micro_bsz * world_size)
gradient_accu_steps = final_batch_size // (micro_batch_size * self.world_size)
if TRAIN_BATCH_SIZE in self._param_dict: if TRAIN_BATCH_SIZE in self._param_dict:
logger.warning("[Elasticity] overriding training_batch_size: "
               f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: "
               f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
if GRADIENT_ACCUMULATION_STEPS in self._param_dict: if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
logger.warning("[Elasticity] overriding gradient_accumulation_steps: "
               f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")
logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")
...@@ -811,8 +767,7 @@ class DeepSpeedConfig(object): ...@@ -811,8 +767,7 @@ class DeepSpeedConfig(object):
def _initialize_params(self, param_dict): def _initialize_params(self, param_dict):
self.train_batch_size = get_train_batch_size(param_dict) self.train_batch_size = get_train_batch_size(param_dict)
#print(f"beginning get_train_batch_size = {get_train_batch_size}") #print(f"beginning get_train_batch_size = {get_train_batch_size}")
self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict)
self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict) self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict)
self.steps_per_print = get_steps_per_print(param_dict) self.steps_per_print = get_steps_per_print(param_dict)
self.dump_state = get_dump_state(param_dict) self.dump_state = get_dump_state(param_dict)
...@@ -824,11 +779,12 @@ class DeepSpeedConfig(object): ...@@ -824,11 +779,12 @@ class DeepSpeedConfig(object):
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict) self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
self.zero_config = get_zero_config(param_dict) self.zero_config = get_zero_config(param_dict)
self.mics_shard_size = self.zero_config.mics_shard_size
self.mics_hierarchial_params_gather = self.zero_config.mics_hierarchical_params_gather
self.zero_optimization_stage = self.zero_config.stage self.zero_optimization_stage = self.zero_config.stage
self.zero_enabled = self.zero_optimization_stage > 0 self.zero_enabled = self.zero_optimization_stage > 0
self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(param_dict)
self.comms_config = DeepSpeedCommsConfig(param_dict) self.comms_config = DeepSpeedCommsConfig(param_dict)
self.monitor_config = get_monitor_config(param_dict) self.monitor_config = get_monitor_config(param_dict)
...@@ -837,9 +793,9 @@ class DeepSpeedConfig(object): ...@@ -837,9 +793,9 @@ class DeepSpeedConfig(object):
self.fp16_enabled = get_fp16_enabled(param_dict) self.fp16_enabled = get_fp16_enabled(param_dict)
self.fp16_auto_cast = get_fp16_auto_cast(param_dict) self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
assert not (self.fp16_enabled
            and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
self.amp_enabled = get_amp_enabled(param_dict) self.amp_enabled = get_amp_enabled(param_dict)
self.amp_params = get_amp_params(param_dict) self.amp_params = get_amp_params(param_dict)
self.loss_scale = get_loss_scale(param_dict) self.loss_scale = get_loss_scale(param_dict)
...@@ -849,22 +805,21 @@ class DeepSpeedConfig(object): ...@@ -849,22 +805,21 @@ class DeepSpeedConfig(object):
self.compression_config = get_compression_config(param_dict) self.compression_config = get_compression_config(param_dict)
self.optimizer_name = get_optimizer_name(param_dict) self.optimizer_name = get_optimizer_name(param_dict)
if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS):
self.optimizer_name = self.optimizer_name.lower() self.optimizer_name = self.optimizer_name.lower()
self.optimizer_params = get_optimizer_params(param_dict) self.optimizer_params = get_optimizer_params(param_dict)
self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict) self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)
self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(param_dict)
self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)
self.scheduler_name = get_scheduler_name(param_dict) self.scheduler_name = get_scheduler_name(param_dict)
self.scheduler_params = get_scheduler_params(param_dict) self.scheduler_params = get_scheduler_params(param_dict)
self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict) self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
self.wall_clock_breakdown = (get_wall_clock_breakdown(param_dict) | self.flops_profiler_config.enabled)
self.memory_breakdown = get_memory_breakdown(param_dict) self.memory_breakdown = get_memory_breakdown(param_dict)
self.autotuning_config = DeepSpeedAutotuningConfig(param_dict) self.autotuning_config = DeepSpeedAutotuningConfig(param_dict)
...@@ -879,6 +834,8 @@ class DeepSpeedConfig(object): ...@@ -879,6 +834,8 @@ class DeepSpeedConfig(object):
self.eigenvalue_layer_num, self.eigenvalue_layer_num,
) = get_eigenvalue_config(param_dict) ) = get_eigenvalue_config(param_dict)
self.hybrid_engine = get_hybrid_engine_config(param_dict)
self.sparse_attention = get_sparse_attention(param_dict) self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict) self.pipeline = get_pipeline_config(param_dict)
...@@ -893,20 +850,16 @@ class DeepSpeedConfig(object): ...@@ -893,20 +850,16 @@ class DeepSpeedConfig(object):
checkpoint_params = get_checkpoint_params(param_dict) checkpoint_params = get_checkpoint_params(param_dict)
validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params) validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params)
self.checkpoint_tag_validation_enabled = (validation_mode != ValidationMode.IGNORE)
self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL
self.load_universal_checkpoint = checkpoint_params.get(LOAD_UNIVERSAL_CHECKPOINT,
                                                       LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)
self.use_node_local_storage = checkpoint_params.get(USE_NODE_LOCAL_STORAGE_CHECKPOINT,
                                                    USE_NODE_LOCAL_STORAGE_CHECKPOINT_DEFAULT)
data_types_params = get_data_types_params(param_dict) data_types_params = get_data_types_params(param_dict)
self.grad_accum_dtype = data_types_params.get(GRAD_ACCUM_DTYPE, GRAD_ACCUM_DTYPE_DEFAULT)
par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params) par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
self.checkpoint_parallel_write_pipeline = par_write_pipe self.checkpoint_parallel_write_pipeline = par_write_pipe
...@@ -923,23 +876,16 @@ class DeepSpeedConfig(object): ...@@ -923,23 +876,16 @@ class DeepSpeedConfig(object):
micro_batch = self.train_micro_batch_size_per_gpu micro_batch = self.train_micro_batch_size_per_gpu
grad_acc = self.gradient_accumulation_steps grad_acc = self.gradient_accumulation_steps
assert (train_batch > 0), f"Train batch size: {train_batch} has to be greater than 0"
assert (micro_batch > 0), f"Micro batch size per gpu: {micro_batch} has to be greater than 0"
assert (grad_acc > 0), f"Gradient accumulation steps: {grad_acc} has to be greater than 0"
assert train_batch == micro_batch * grad_acc * self.world_size, (
    f"Check batch related parameters. train_batch_size is not equal "
    "to micro_batch_per_gpu * gradient_acc_step * world_size "
    f"{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}")
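# Editor's note: worked example of the invariant asserted above,
# train_batch_size = micro_batch_per_gpu * gradient_accumulation_steps * world_size:
# with 4 GPUs, micro batch 4 and 2 accumulation steps, the only valid train batch is 32.
micro_batch, grad_acc, world_size = 4, 2, 4
assert micro_batch * grad_acc * world_size == 32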
def _set_batch_related_parameters(self): def _set_batch_related_parameters(self):
...@@ -1002,8 +948,7 @@ class DeepSpeedConfig(object): ...@@ -1002,8 +948,7 @@ class DeepSpeedConfig(object):
sort_keys=True, sort_keys=True,
indent=4, indent=4,
cls=ScientificNotationEncoder, cls=ScientificNotationEncoder,
separators=(",", separators=(",", ":"),
":"),
))) )))
def print(self, name): def print(self, name):
...@@ -1016,20 +961,16 @@ class DeepSpeedConfig(object): ...@@ -1016,20 +961,16 @@ class DeepSpeedConfig(object):
self.print_user_config() self.print_user_config()
def _do_error_check(self): def _do_error_check(self):
assert (self.train_micro_batch_size_per_gpu
        ), "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
assert (
    self.gradient_accumulation_steps), "DeepSpeedConfig: {} is not defined".format(GRADIENT_ACCUMULATION_STEPS)
if self.zero_enabled:
    assert (self.zero_optimization_stage <=
            ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
                ZeroStageEnum.max_stage)
if self.fp16_master_weights_and_gradients: if self.fp16_master_weights_and_gradients:
assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."
...@@ -1040,19 +981,15 @@ class DeepSpeedConfig(object): ...@@ -1040,19 +981,15 @@ class DeepSpeedConfig(object):
vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
logger.warning(
    "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may impact tensor core utilization.".format(
        vocabulary_size, TENSOR_CORE_ALIGN_SIZE))
if (self.optimizer_params is not None and MAX_GRAD_NORM in self.optimizer_params.keys()
        and self.optimizer_params[MAX_GRAD_NORM] > 0):
if fp16_enabled: if fp16_enabled:
if self.global_rank == 0: if self.global_rank == 0:
logger.warning("DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper".format(
    MAX_GRAD_NORM, self.optimizer_params[MAX_GRAD_NORM]))
else: else:
if self.global_rank == 0: if self.global_rank == 0:
logger.warning( logger.warning(
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
""" """
Collection of DeepSpeed configuration utilities Collection of DeepSpeed configuration utilities
""" """
...@@ -50,15 +49,10 @@ class DeepSpeedConfigModel(BaseModel): ...@@ -50,15 +49,10 @@ class DeepSpeedConfigModel(BaseModel):
new_param='my_new_field', new_param='my_new_field',
new_param_fn=(lambda x: int(x))) new_param_fn=(lambda x: int(x)))
""" """
def __init__(self, strict=False, **data): def __init__(self, strict=False, **data):
    if (not strict):  # This is temporary until we refactor all DS configs, allows HF to load models
        data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
super().__init__(**data) super().__init__(**data)
self._deprecated_fields_check(self) self._deprecated_fields_check(self)
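# Editor's note: illustrative effect of the non-strict filtering above; "auto"
# placeholders are dropped so pydantic falls back to the field defaults, while
# "replace_method" is kept even when set to "auto".
_raw = {"stage": 3, "offload_param": "auto", "replace_method": "auto"}
_filtered = {k: v for k, v in _raw.items() if (v != "auto" or k == "replace_method")}
assert _filtered == {"stage": 3, "replace_method": "auto"}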
...@@ -73,8 +67,7 @@ class DeepSpeedConfigModel(BaseModel): ...@@ -73,8 +67,7 @@ class DeepSpeedConfigModel(BaseModel):
dep_msg = kwargs.get("deprecated_msg", "") dep_msg = kwargs.get("deprecated_msg", "")
if dep_param in fields_set: if dep_param in fields_set:
logger.warning(f"Config parameter {dep_param} is deprecated" + logger.warning(f"Config parameter {dep_param} is deprecated" +
(f" use {new_param} instead" if new_param else "") + (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
(f". {dep_msg}" if dep_msg else ""))
# Check if there is a new param and if it should be set with a value # Check if there is a new param and if it should be set with a value
if new_param and kwargs.get("set_new_param", True): if new_param and kwargs.get("set_new_param", True):
# Remove the deprecate field if there is a replacing field # Remove the deprecate field if there is a replacing field
...@@ -89,9 +82,7 @@ class DeepSpeedConfigModel(BaseModel): ...@@ -89,9 +82,7 @@ class DeepSpeedConfigModel(BaseModel):
if len(new_param_nested) > 1: if len(new_param_nested) > 1:
# If the new param exists in a subconfig, we need to get # If the new param exists in a subconfig, we need to get
# the fields set for that subconfig # the fields set for that subconfig
pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
fields_set = pydantic_config.__fields_set__ fields_set = pydantic_config.__fields_set__
new_param_name = new_param_nested[-1] new_param_name = new_param_nested[-1]
assert ( assert (
...@@ -101,9 +92,7 @@ class DeepSpeedConfigModel(BaseModel): ...@@ -101,9 +92,7 @@ class DeepSpeedConfigModel(BaseModel):
try: try:
setattr(pydantic_config, new_param_name, param_value) setattr(pydantic_config, new_param_name, param_value)
except Exception as e: except Exception as e:
logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
raise e raise e
def _deprecated_fields_check(self, pydantic_config): def _deprecated_fields_check(self, pydantic_config):
...@@ -121,12 +110,20 @@ class DeepSpeedConfigModel(BaseModel): ...@@ -121,12 +110,20 @@ class DeepSpeedConfigModel(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def get_config_default(config, field_name):
assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
assert not config.__fields__.get(
field_name).required, f"'{field_name}' is a required field and does not have a default value"
return config.__fields__.get(field_name).default
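Usage sketch for the new helper, reusing the illustrative config above: it reads the declared default of a field without instantiating the model, and the asserts fire if the field is unknown or required.

assert get_config_default(ExampleConfig, "my_new_field") == 0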
class pp_int(int): class pp_int(int):
""" """
A wrapper for integers that will return a custom string or comma-formatted A wrapper for integers that will return a custom string or comma-formatted
string of the integer. For example, print(pp_int(1e5)) will return string of the integer. For example, print(pp_int(1e5)) will return
"100,000". This is useful mainly for auto-generated documentation purposes. "100,000". This is useful mainly for auto-generated documentation purposes.
""" """
def __new__(cls, val, custom_print_str=None): def __new__(cls, val, custom_print_str=None):
inst = super().__new__(cls, val) inst = super().__new__(cls, val)
inst.custom_print_str = custom_print_str inst.custom_print_str = custom_print_str
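A quick illustration of pp_int, assuming the elided string formatting behaves as the docstring describes (thousands separators, with custom_print_str taking precedence):

print(pp_int(100000))            # 100,000
print(pp_int(0, "disabled"))     # disabled -- custom_print_str wins
print(pp_int(100000) + 1)        # 100001 -- arithmetic still yields a plain int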
...@@ -148,6 +145,7 @@ class ScientificNotationEncoder(json.JSONEncoder): ...@@ -148,6 +145,7 @@ class ScientificNotationEncoder(json.JSONEncoder):
Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it
""" """
def iterencode(self, o, _one_shot=False, level=0): def iterencode(self, o, _one_shot=False, level=0):
indent = self.indent if self.indent is not None else 4 indent = self.indent if self.indent is not None else 4
prefix_close = " " * level * indent prefix_close = " " * level * indent
...@@ -161,10 +159,7 @@ class ScientificNotationEncoder(json.JSONEncoder): ...@@ -161,10 +159,7 @@ class ScientificNotationEncoder(json.JSONEncoder):
else: else:
return f"{o}" return f"{o}"
elif isinstance(o, collections.abc.Mapping): elif isinstance(o, collections.abc.Mapping):
x = [ x = [f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, v in o.items()]
f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k,
v in o.items()
]
return "{" + ", ".join(x) + f"\n{prefix_close}" + "}" return "{" + ", ".join(x) + f"\n{prefix_close}" + "}"
elif isinstance(o, collections.abc.Sequence) and not isinstance(o, str): elif isinstance(o, collections.abc.Sequence) and not isinstance(o, str):
return f"[{ f', '.join(map(self.iterencode, o)) }]" return f"[{ f', '.join(map(self.iterencode, o)) }]"
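A short sketch of the encoder in use, assuming the elided numeric branch emits large values in scientific notation as the class name and docstring imply; the config keys below are illustrative only:

import json

cfg = {"train_batch_size": 32, "reduce_bucket_size": 500000000}
print(json.dumps(cfg, cls=ScientificNotationEncoder, indent=4))
# Small values print as-is; large ones (e.g. reduce_bucket_size) come out in
# scientific notation, which keeps auto-generated documentation readable.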
...@@ -175,6 +170,7 @@ class DeepSpeedConfigObject(object): ...@@ -175,6 +170,7 @@ class DeepSpeedConfigObject(object):
""" """
For json serialization For json serialization
""" """
def repr(self): def repr(self):
return self.__dict__ return self.__dict__
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
""" # SPDX-License-Identifier: Apache-2.0
Copyright (c) Microsoft Corporation
Licensed under the MIT license. # DeepSpeed Team
"""
############################################# #############################################
# Routes # Routes
...@@ -73,6 +72,8 @@ MAX_GRAD_NORM = 'max_grad_norm' ...@@ -73,6 +72,8 @@ MAX_GRAD_NORM = 'max_grad_norm'
############################################# #############################################
ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer" ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
ZERO_FORCE_DS_CPU_OPTIMIZER = "zero_force_ds_cpu_optimizer"
ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT = True
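A hedged sketch of where the new key would sit in a user config dict; the semantics are inferred from the constant name only (when False, ZeRO presumably would not insist on DeepSpeed's own CPU optimizer for offloaded optimizer states):

ds_config = {
    "zero_allow_untested_optimizer": False,
    "zero_force_ds_cpu_optimizer": True,   # new flag, defaults to True
}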
# Steps # Steps
STEPS_PER_PRINT = "steps_per_print" STEPS_PER_PRINT = "steps_per_print"
...@@ -368,11 +369,7 @@ class ValidationMode: ...@@ -368,11 +369,7 @@ class ValidationMode:
CHECKPOINT = "checkpoint" CHECKPOINT = "checkpoint"
CHECKPOINT_TAG_VALIDATION = "tag_validation" CHECKPOINT_TAG_VALIDATION = "tag_validation"
CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN
CHECKPOINT_TAG_VALIDATION_MODES = [ CHECKPOINT_TAG_VALIDATION_MODES = [ValidationMode.WARN, ValidationMode.IGNORE, ValidationMode.FAIL]
ValidationMode.WARN,
ValidationMode.IGNORE,
ValidationMode.FAIL
]
LOAD_UNIVERSAL_CHECKPOINT = "load_universal" LOAD_UNIVERSAL_CHECKPOINT = "load_universal"
LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
Implementation of a compressed sparse row (CSR) tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""
import torch
class CSRTensor(object):
""" Compressed Sparse Row (CSR) Tensor """
def __init__(self, dense_tensor=None):
self.orig_dense_tensor = dense_tensor
if dense_tensor is not None:
result = torch.sum(dense_tensor, dim=1)
self.indices = result.nonzero().flatten()
self.values = dense_tensor[self.indices]
self.dense_size = list(dense_tensor.size())
else:
self.indices = None
self.values = None
self.dense_size = None
@staticmethod
def type():
return "deepspeed.CSRTensor"
def to_dense(self):
it = self.indices.unsqueeze(1)
full_indices = torch.cat([it for _ in range(self.dense_size[1])], dim=1)
return self.values.new_zeros(self.dense_size).scatter_add_(
0,
full_indices,
self.values)
def sparse_size(self):
index_size = list(self.indices.size())
index_size = index_size[0]
value_size = list(self.values.size())
value_size = value_size[0] * value_size[1]
dense_size = self.dense_size[0] * self.dense_size[1]
return index_size + value_size, dense_size
def add(self, b):
assert self.dense_size == b.dense_size
self.indices = torch.cat([self.indices, b.indices])
self.values = torch.cat([self.values, b.values])
def __str__(self):
sparse_size, dense_size = self.sparse_size()
return "DeepSpeed.CSRTensor(indices_size={}, values_size={}, " \
"dense_size={}, device={}, reduction_factor={})".format(
self.indices.size(), self.values.size(), self.dense_size,
self.indices.get_device(), dense_size / sparse_size
)
def __repr__(self):
return self.__str__()
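A minimal round-trip sketch of the class above, for a 2-D tensor with a single non-zero row (values illustrative):

import torch

dense = torch.zeros(4, 3)
dense[1] = torch.tensor([1.0, 2.0, 3.0])    # only row 1 is non-zero
csr = CSRTensor(dense)

sparse_elems, dense_elems = csr.sparse_size()
print(sparse_elems, dense_elems)            # 4 12 -> 1 index + 3 values vs. 12 dense entries

restored = csr.to_dense()
assert torch.equal(restored, dense)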
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from .constants import * from .constants import *
import copy import copy
from ..config_utils import get_scalar_param from ..config_utils import get_scalar_param
...@@ -24,18 +26,14 @@ def get_data_efficiency_config(param_dict): ...@@ -24,18 +26,14 @@ def get_data_efficiency_config(param_dict):
def get_data_efficiency_enabled(param_dict): def get_data_efficiency_enabled(param_dict):
if DATA_EFFICIENCY in param_dict.keys(): if DATA_EFFICIENCY in param_dict.keys():
return get_scalar_param(param_dict[DATA_EFFICIENCY], return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_ENABLED, DATA_EFFICIENCY_ENABLED_DEFAULT)
DATA_EFFICIENCY_ENABLED,
DATA_EFFICIENCY_ENABLED_DEFAULT)
else: else:
return False return False
def get_data_efficiency_seed(param_dict): def get_data_efficiency_seed(param_dict):
if DATA_EFFICIENCY in param_dict.keys(): if DATA_EFFICIENCY in param_dict.keys():
return get_scalar_param(param_dict[DATA_EFFICIENCY], return get_scalar_param(param_dict[DATA_EFFICIENCY], DATA_EFFICIENCY_SEED, DATA_EFFICIENCY_SEED_DEFAULT)
DATA_EFFICIENCY_SEED,
DATA_EFFICIENCY_SEED_DEFAULT)
else: else:
return DATA_EFFICIENCY_SEED_DEFAULT return DATA_EFFICIENCY_SEED_DEFAULT
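The getters above all read one nested section of the user config; a minimal sketch of the expected dict shape, with the "enabled" key taken from the constants file further down and the "seed" key name assumed from DATA_EFFICIENCY_SEED:

param_dict = {
    "data_efficiency": {
        "enabled": True,
        "seed": 1234,
    }
}
assert get_data_efficiency_enabled(param_dict) is True
assert get_data_efficiency_seed(param_dict) == 1234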
...@@ -55,26 +53,21 @@ def get_data_sampling(param_dict): ...@@ -55,26 +53,21 @@ def get_data_sampling(param_dict):
def get_data_sampling_enabled(param_dict): def get_data_sampling_enabled(param_dict):
if DATA_SAMPLING in param_dict.keys(): if DATA_SAMPLING in param_dict.keys():
return get_scalar_param(param_dict[DATA_SAMPLING], return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_ENABLED, DATA_SAMPLING_ENABLED_DEFAULT)
DATA_SAMPLING_ENABLED,
DATA_SAMPLING_ENABLED_DEFAULT)
else: else:
return False return False
def get_data_sampling_num_epochs(param_dict): def get_data_sampling_num_epochs(param_dict):
if DATA_SAMPLING in param_dict.keys(): if DATA_SAMPLING in param_dict.keys():
return get_scalar_param(param_dict[DATA_SAMPLING], return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_EPOCHS, DATA_SAMPLING_NUM_EPOCHS_DEFAULT)
DATA_SAMPLING_NUM_EPOCHS,
DATA_SAMPLING_NUM_EPOCHS_DEFAULT)
else: else:
return DATA_SAMPLING_NUM_EPOCHS_DEFAULT return DATA_SAMPLING_NUM_EPOCHS_DEFAULT
def get_data_sampling_num_workers(param_dict): def get_data_sampling_num_workers(param_dict):
if DATA_SAMPLING in param_dict.keys(): if DATA_SAMPLING in param_dict.keys():
return get_scalar_param(param_dict[DATA_SAMPLING], return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_NUM_WORKERS,
DATA_SAMPLING_NUM_WORKERS,
DATA_SAMPLING_NUM_WORKERS_DEFAULT) DATA_SAMPLING_NUM_WORKERS_DEFAULT)
else: else:
return DATA_SAMPLING_NUM_WORKERS_DEFAULT return DATA_SAMPLING_NUM_WORKERS_DEFAULT
...@@ -87,7 +80,8 @@ def get_curriculum_learning(param_dict): ...@@ -87,7 +80,8 @@ def get_curriculum_learning(param_dict):
param_dict[CURRICULUM_LEARNING] = {} param_dict[CURRICULUM_LEARNING] = {}
sub_param_dict = param_dict[CURRICULUM_LEARNING] sub_param_dict = param_dict[CURRICULUM_LEARNING]
if output[CURRICULUM_LEARNING_ENABLED]: if output[CURRICULUM_LEARNING_ENABLED]:
assert CURRICULUM_LEARNING_METRICS in sub_param_dict.keys(), f"Curriculum learning is enabled, {CURRICULUM_LEARNING_METRICS} must be specified" assert CURRICULUM_LEARNING_METRICS in sub_param_dict.keys(
), f"Curriculum learning is enabled, {CURRICULUM_LEARNING_METRICS} must be specified"
for key, val in get_curriculum_learning_params(param_dict).items(): for key, val in get_curriculum_learning_params(param_dict).items():
output[key] = val output[key] = val
return output return output
...@@ -95,8 +89,7 @@ def get_curriculum_learning(param_dict): ...@@ -95,8 +89,7 @@ def get_curriculum_learning(param_dict):
def get_curriculum_learning_enabled(param_dict): def get_curriculum_learning_enabled(param_dict):
if CURRICULUM_LEARNING in param_dict.keys(): if CURRICULUM_LEARNING in param_dict.keys():
return get_scalar_param(param_dict[CURRICULUM_LEARNING], return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED,
CURRICULUM_LEARNING_ENABLED,
CURRICULUM_LEARNING_ENABLED_DEFAULT) CURRICULUM_LEARNING_ENABLED_DEFAULT)
else: else:
return False return False
...@@ -113,8 +106,7 @@ def get_curriculum_learning_params(param_dict): ...@@ -113,8 +106,7 @@ def get_curriculum_learning_params(param_dict):
def get_curriculum_enabled_legacy(param_dict): def get_curriculum_enabled_legacy(param_dict):
if CURRICULUM_LEARNING_LEGACY in param_dict.keys(): if CURRICULUM_LEARNING_LEGACY in param_dict.keys():
return get_scalar_param(param_dict[CURRICULUM_LEARNING_LEGACY], return get_scalar_param(param_dict[CURRICULUM_LEARNING_LEGACY], CURRICULUM_ENABLED_LEGACY,
CURRICULUM_ENABLED_LEGACY,
CURRICULUM_ENABLED_DEFAULT_LEGACY) CURRICULUM_ENABLED_DEFAULT_LEGACY)
else: else:
return False return False
...@@ -142,9 +134,7 @@ def get_data_routing(param_dict): ...@@ -142,9 +134,7 @@ def get_data_routing(param_dict):
def get_data_routing_enabled(param_dict): def get_data_routing_enabled(param_dict):
if DATA_ROUTING in param_dict.keys(): if DATA_ROUTING in param_dict.keys():
return get_scalar_param(param_dict[DATA_ROUTING], return get_scalar_param(param_dict[DATA_ROUTING], DATA_ROUTING_ENABLED, DATA_ROUTING_ENABLED_DEFAULT)
DATA_ROUTING_ENABLED,
DATA_ROUTING_ENABLED_DEFAULT)
else: else:
return False return False
...@@ -164,9 +154,7 @@ def get_random_ltd(param_dict): ...@@ -164,9 +154,7 @@ def get_random_ltd(param_dict):
def get_random_ltd_enabled(param_dict): def get_random_ltd_enabled(param_dict):
if RANDOM_LTD in param_dict.keys(): if RANDOM_LTD in param_dict.keys():
return get_scalar_param(param_dict[RANDOM_LTD], return get_scalar_param(param_dict[RANDOM_LTD], RANDOM_LTD_ENABLED, RANDOM_LTD_ENABLED_DEFAULT)
RANDOM_LTD_ENABLED,
RANDOM_LTD_ENABLED_DEFAULT)
else: else:
return False return False
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
######################################### # DeepSpeed Team
# Data efficiency library """
# See sample config at https://www.deepspeed.ai/docs/config-json/#data-efficiency Data efficiency library
######################################### See sample config at https://www.deepspeed.ai/docs/config-json/data-efficiency
"""
DATA_EFFICIENCY = "data_efficiency" DATA_EFFICIENCY = "data_efficiency"
DATA_EFFICIENCY_ENABLED = "enabled" DATA_EFFICIENCY_ENABLED = "enabled"
DATA_EFFICIENCY_ENABLED_DEFAULT = False DATA_EFFICIENCY_ENABLED_DEFAULT = False
......