Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
-'''
-Copyright 2019 The Microsoft DeepSpeed Team
-'''
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
 import torch
 import os
@@ -10,14 +11,8 @@ from collections import OrderedDict
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
-from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank,
-                                     get_global_norm,
-                                     empty_cache,
-                                     see_memory_usage,
-                                     inf,
-                                     is_model_parallel_parameter,
-                                     align_dense_tensors,
-                                     all_gather_dp_groups)
+from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage,
+                                     inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
@@ -30,14 +25,9 @@ from deepspeed.runtime.constants import PIPE_REPLICATED
 from deepspeed.accelerator import get_accelerator
 from deepspeed.ops.op_builder import UtilsBuilder
-from deepspeed.checkpoint.constants import (DS_VERSION,
-                                            GROUP_PADDINGS,
-                                            PARTITION_COUNT,
-                                            SINGLE_PARTITION_OF_FP32_GROUPS,
-                                            BASE_OPTIMIZER_STATE,
-                                            CLIP_GRAD,
-                                            ZERO_STAGE,
-                                            PARAM_SLICE_MAPPINGS)
+from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT,
+                                            SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, CLIP_GRAD,
+                                            ZERO_STAGE, PARAM_SLICE_MAPPINGS)
 from deepspeed.utils import link_hp_params
 from deepspeed.checkpoint import enable_universal_checkpoint
@@ -53,10 +43,8 @@ def input(msg):
 def split_half_float_double(tensors):
     device_type = get_accelerator().device_name()
     dtypes = [
-        "torch.{}.HalfTensor".format(device_type),
-        "torch.{}.FloatTensor".format(device_type),
-        "torch.{}.DoubleTensor".format(device_type),
-        "torch.{}.BFloat16Tensor".format(device_type)
+        "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
+        "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type)
     ]
     buckets = []
     for i, dtype in enumerate(dtypes):
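
The bucketing above keys off tensor type strings rather than torch.dtype objects. A minimal, runnable sketch of that behavior outside DeepSpeed (names here are illustrative, not from this patch):

import torch

def bucket_by_type(tensors, dtypes):
    # tensor.type() returns strings like "torch.FloatTensor" or
    # "torch.cuda.HalfTensor", which are matched against the dtypes list
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((i, bucket))
    return buckets

mixed = [torch.ones(2), torch.ones(2, dtype=torch.float64)]
print(bucket_by_type(mixed, ["torch.FloatTensor", "torch.DoubleTensor"]))
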
@@ -110,6 +98,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     For usage examples, refer to TODO: DeepSpeed Tutorial
     """
+
     def __init__(self,
                  init_optimizer,
                  param_names,
@@ -168,6 +157,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # ZeRO stage 1 (False) or 2 (True)
         self.partition_gradients = partition_grads
+        self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"
 
         self.timers = timers
@@ -179,8 +169,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.deepspeed_adam_offload = cpu_offload
 
-        self.device = get_accelerator().current_device_name(
-        ) if not self.cpu_offload else 'cpu'
+        self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'
 
         self.dp_process_group = dp_process_group
@@ -195,9 +184,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         #For MoE models this may be different for different param groups
         #It will be modified during MoE setup later in the init
-        self.real_dp_process_group = [
-            dp_process_group for i in range(len(self.optimizer.param_groups))
-        ]
+        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
         self.partition_count = [dp_size for i in range(len(self.optimizer.param_groups))]
 
         self.is_gradient_accumulation_boundary = True
@@ -233,12 +220,16 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
 
         if self.fp16_master_weights_and_gradients:
-            assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32. Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. Either disable fp16_master_weights_and_gradients or enable ZeRO-2 Offload with DeepSpeedCPUAdam"
+            assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \
+                f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32. " \
+                f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. " \
+                f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam."
 
         if self.reduce_scatter:
-            assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-2 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
-            assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
-            assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
+            valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
+            assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
+            assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
+            assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
 
         # param flattened by groups
         self.bit16_groups = []
@@ -272,7 +263,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # align nccl all-gather send buffers to 4-byte boundary
         self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2
 
-        assert (allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
+        assert (
+            allgather_bucket_size % self.nccl_start_alignment_factor == 0
+        ), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
 
         self.all_reduce_print = False
         self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
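
A sketch of the alignment arithmetic this assert protects (assumed semantics: with 2-byte fp16 elements, an alignment factor of 2 keeps every partition start on a 4-byte boundary):

def pad_to_alignment(numel, alignment):
    # round numel up to the next multiple of alignment
    remainder = numel % alignment
    return numel if remainder == 0 else numel + alignment - remainder

# 2 fp16 elements = 4 bytes, so a factor of 2 gives 4-byte-aligned starts
assert pad_to_alignment(10, 2) == 10
assert pad_to_alignment(11, 2) == 12
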
@@ -289,9 +282,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             # push this group to list before modify
             # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
-            trainable_parameters = [
-                param for param in param_group['params'] if param.requires_grad
-            ]
+            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
             self.bit16_groups.append(trainable_parameters)
 
             # not sure why apex was cloning the weights before flattening
@@ -309,9 +300,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
             if self.round_robin_gradients:
                 round_robin_tensors, round_robin_indices = self._round_robin_reorder(
-                    self.bit16_groups[i],
-                    dist.get_world_size(group=self.real_dp_process_group[i])
-                )
+                    self.bit16_groups[i], dist.get_world_size(group=self.real_dp_process_group[i]))
             else:
                 round_robin_tensors = self.bit16_groups[i]
                 round_robin_indices = list(range(len(self.bit16_groups[i])))
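
A hedged sketch of what a round-robin reorder like `_round_robin_reorder` does (illustrative reimplementation, not the patched code): parameters are dealt out to data-parallel partitions in turn, so gradients that become ready back-to-back reduce to different ranks:

def round_robin_reorder(tensor_list, num_partitions):
    partition_tensors = {i: [] for i in range(num_partitions)}
    for i, tensor in enumerate(tensor_list):
        partition_tensors[i % num_partitions].append((i, tensor))
    reordered, reordered_indices = [], {}
    for partition_index in partition_tensors:
        for original_index, tensor in partition_tensors[partition_index]:
            reordered_indices[original_index] = len(reordered)
            reordered.append(tensor)
    return reordered, reordered_indices

# four tensors over two partitions: [t0, t2] then [t1, t3]
reordered, idx = round_robin_reorder(['t0', 't1', 't2', 't3'], 2)
assert reordered == ['t0', 't2', 't1', 't3']
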
@@ -323,15 +312,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.bit16_groups_flat.append(
                 self.flatten_dense_tensors_aligned(
                     self.round_robin_bit16_groups[i],
-                    self.nccl_start_alignment_factor *
-                    dist.get_world_size(group=self.real_dp_process_group[i])).to(
+                    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
                         get_accelerator().current_device_name()))
-            see_memory_usage(f"After flattening and moving param group {i} to GPU",
-                             force=False)
+            see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
 
             # Record padding required for alignment
-            if partition_id == dist.get_world_size(
-                    group=self.real_dp_process_group[i]) - 1:
+            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                 padding = self.bit16_groups_flat[i].numel() - sum(
                     [t.numel() for t in self.round_robin_bit16_groups[i]])
             else:
@@ -339,36 +325,29 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.groups_padding.append(padding)
 
             if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
-                see_memory_usage(
-                    f"After Flattening and after emptying param group {i} cache",
-                    force=False)
+                see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)
 
             # set model bit16 weight to slices of flattened buffer
             self._update_model_bit16_weights(i)
 
             # divide the flat weights into near equal partition equal to the data parallel degree
             # each process will compute on a different part of the partition
-            data_parallel_partitions = self.get_data_parallel_partitions(
-                self.bit16_groups_flat[i],
-                i)
+            data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
             self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)
 
             # verify that data partition start locations are 4-byte aligned
             for partitioned_data in data_parallel_partitions:
-                assert (partitioned_data.data_ptr() %
-                        (2 * self.nccl_start_alignment_factor) == 0)
+                assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)
 
             # A partition of the fp32 master weights that will be updated by this process.
             # Note that the params in single_partition_of_fp32_groups is cloned and detached
             # from the origin params of the model.
             if not fp16_master_weights_and_gradients:
-                self.single_partition_of_fp32_groups.append(
-                    self.parallel_partitioned_bit16_groups[i][partition_id].to(
-                        self.device).clone().float().detach())
+                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
+                    self.device).clone().float().detach())
             else:
-                self.single_partition_of_fp32_groups.append(
-                    self.parallel_partitioned_bit16_groups[i][partition_id].to(
-                        self.device).clone().half().detach())
+                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
+                    self.device).clone().half().detach())
 
             # Set local optimizer to have flat params of its own partition.
             # After this, the local optimizer will only contain its own partition of params.
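
The partitioning above hands each data-parallel rank a view into the flat buffer. A small sketch of that idea (simplified: the real `get_data_parallel_partitions` also handles uneven remainders):

import torch

def get_partitions(flat_tensor, world_size):
    partition_size = flat_tensor.numel() // world_size
    # narrow() returns views, so partitions share storage with flat_tensor
    return [flat_tensor.narrow(0, r * partition_size, partition_size) for r in range(world_size)]

partitions = get_partitions(torch.arange(8.), world_size=4)
assert partitions[1].tolist() == [2.0, 3.0]
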
@@ -377,12 +356,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 i].requires_grad = True  # keep this in case internal optimizer uses it
             param_group['params'] = [self.single_partition_of_fp32_groups[i]]
 
-            partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(
-                group=self.real_dp_process_group[i])
+            partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i])
             params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
-                self.round_robin_bit16_groups[i],
-                partition_size,
-                partition_id)
+                self.round_robin_bit16_groups[i], partition_size, partition_id)
 
             self.partition_size.append(partition_size)
             self.params_in_partition.append(params_in_partition)
@@ -399,8 +375,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.reduce_bucket_size = int(reduce_bucket_size)
         self.allgather_bucket_size = int(allgather_bucket_size)
 
-        self.reduction_event = get_accelerator().Event(enable_timing=False,
-                                                       blocking=False)
+        self.reduction_event = get_accelerator().Event(enable_timing=False, blocking=False)
         self.reduction_stream = get_accelerator().Stream()
         self.cpu_computation_stream = get_accelerator().Stream()
         self.copy_grad_stream = get_accelerator().Stream()
@@ -449,18 +424,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.local_overflow = False
             self.grad_position = {}
             self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
-                torch.zeros(largest_param_numel,
-                            device=self.device,
-                            dtype=self.dtype))
-            self.temp_grad_buffer_for_gpu_offload = torch.zeros(
-                largest_param_numel,
-                device=get_accelerator().current_device_name(),
-                dtype=self.dtype)
+                torch.zeros(largest_param_numel, device=self.device, dtype=self.dtype))
+            self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
+                                                                device=get_accelerator().current_device_name(),
+                                                                dtype=self.dtype)
             for i, params_group in enumerate(self.bit16_groups):
-                self.get_grad_position(i,
-                                       self.params_in_partition[i],
-                                       self.first_offset[i],
-                                       self.partition_size[i])
+                self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])
 
         # mapping from parameter to partition that it belongs to
         self.param_to_partition_ids = {}
@@ -537,8 +506,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             for lp in self.bit16_groups[i]:
                 if lp._hp_mapping is not None:
                     lp_name = self.param_names[lp]
-                    param_mapping_per_group[
-                        lp_name] = lp._hp_mapping.get_hp_fragment_address()
+                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
             param_mapping.append(param_mapping_per_group)
 
         return param_mapping
@@ -553,17 +521,16 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
             flat_hp_partition = self.single_partition_of_fp32_groups[i]
-            link_hp_params(
-                lp_param_list=self.bit16_groups[i],
-                flat_hp_partition=flat_hp_partition,
-                gradient_dict=self.averaged_gradients,
-                offload_gradient_dict=self.offload_gradient_dict,
-                use_offload=self.cpu_offload,
-                param_group_index=i,
-                partition_start=partition_id * partition_size,
-                partition_size=partition_size,
-                partition_optimizer_state=self.optimizer.state[flat_hp_partition],
-                dp_group=self.real_dp_process_group[i])
+            link_hp_params(lp_param_list=self.bit16_groups[i],
+                           flat_hp_partition=flat_hp_partition,
+                           gradient_dict=self.averaged_gradients,
+                           offload_gradient_dict=self.offload_gradient_dict,
+                           use_offload=self.cpu_offload,
+                           param_group_index=i,
+                           partition_start=partition_id * partition_size,
+                           partition_size=partition_size,
+                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
+                           dp_group=self.real_dp_process_group[i])
 
     def is_moe_group(self, group):
         return 'moe' in group and group['moe']
@@ -575,19 +542,19 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
         if not self.partition_gradients and not self.contiguous_gradients:
             logger.warn(
-                "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental."
-            )
+                "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
         assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
 
-        assert any([self.is_moe_group(group) for group in self.optimizer.param_groups]), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
+        assert any(
+            [self.is_moe_group(group) for group in self.optimizer.param_groups]
+        ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
 
         self.is_moe_param_group = []
         for i, group in enumerate(self.optimizer.param_groups):
             if self.is_moe_group(group):
-                assert all([is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params"
-                self.real_dp_process_group[i] = self.expert_dp_process_group[
-                    group['name']]
-                self.partition_count[i] = dist.get_world_size(
-                    group=self.expert_dp_process_group[group['name']])
+                assert all([is_moe_param(param)
+                            for param in group['params']]), "All params in MoE group must be MoE params"
+                self.real_dp_process_group[i] = self.expert_dp_process_group[group['name']]
+                self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']])
                 self.is_moe_param_group.append(True)
             else:
                 self.is_moe_param_group.append(False)
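
A sketch of the param-group convention these asserts enforce (the 'moe'/'name' keys follow the usage visible above; the concrete values are illustrative):

import torch

dense_params = [torch.nn.Parameter(torch.zeros(2))]
expert_params = [torch.nn.Parameter(torch.zeros(2))]
param_groups = [
    {'params': dense_params},
    {'params': expert_params, 'moe': True, 'name': 'expert_group_0'},
]
# mirrors is_moe_group(): at least one group must carry the 'moe' flag
assert any('moe' in g and g['moe'] for g in param_groups)
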
@@ -638,14 +605,19 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def initialize_optimizer_states(self):
 
         for i, group in enumerate(self.bit16_groups):
-            single_grad_partition = torch.zeros(
-                int(self.partition_size[i]),
-                dtype=self.single_partition_of_fp32_groups[i].dtype,
-                device=self.device)
+            single_grad_partition = torch.zeros(int(self.partition_size[i]),
+                                                dtype=self.single_partition_of_fp32_groups[i].dtype,
+                                                device=self.device)
             self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
                 single_grad_partition) if self.cpu_offload else single_grad_partition
 
-        self.optimizer.step()
+        # Initialize the optimizer states with the flattened fp32 partition.
+        # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
+        # which do lazy initialization of the state at the first call to step.
+        if isinstance(self.optimizer, torch.optim.Adagrad):
+            self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
+        else:
+            self.optimizer.step()
 
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
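
The Adagrad special case above exists because `torch.optim.Adagrad` allocates its state in the constructor, while most optimizers create state lazily on the first `step()`. A runnable sketch of that difference:

import torch

p = torch.nn.Parameter(torch.zeros(4))

adagrad = torch.optim.Adagrad([p])
print('sum' in adagrad.state[p])  # True: state exists before any step()

adam = torch.optim.Adam([p])
print(len(adam.state[p]) == 0)    # True: Adam has no state until step()
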
@@ -709,11 +681,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 self.total_grads_in_partition[i][partition_id] = 0
                 self.initialize_gradient_partition(i, param_group, partition_id)
                 self.is_partition_reduced[i][partition_id] = False
-                self.first_param_index_in_partition[i][
-                    partition_id] = self.get_first_param_index(
-                        i,
-                        param_group,
-                        partition_id)
+                self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index(
+                    i, param_group, partition_id)
 
     def independent_gradient_partition_epilogue(self):
         self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
@@ -742,13 +711,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                                                                      device=get_accelerator().current_device_name(),
                                                                      return_tensor_list=True)
             else:
-                avg_new = self.get_flat_partition(
-                    self.params_in_partition[i],
-                    self.first_offset[i],
-                    self.partition_size[i],
-                    dtype=self.dtype,
-                    device=get_accelerator().current_device_name(),
-                    return_tensor_list=True)
+                avg_new = self.get_flat_partition(self.params_in_partition[i],
+                                                  self.first_offset[i],
+                                                  self.partition_size[i],
+                                                  dtype=self.dtype,
+                                                  device=get_accelerator().current_device_name(),
+                                                  return_tensor_list=True)
 
                 for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
                     accumulated_grad.add_(new_avg_grad)
@@ -769,13 +737,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             total_partitions = dist.get_world_size(group=self.real_dp_process_group[i])
             for partition_id in range(total_partitions):
                 self.is_partition_reduced[i][partition_id] = False
-                self.remaining_grads_in_partition[i][
-                    partition_id] = self.total_grads_in_partition[i][partition_id]
+                self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id]
                 for param_id in self.is_grad_computed[i][partition_id]:
                     self.is_grad_computed[i][partition_id][param_id] = False
 
     def initialize_gradient_partition(self, i, param_group, partition_id):
+
         def set_key_value_list(dictionary, key, value):
             if key in dictionary:
                 dictionary[key].append(value)
@@ -802,25 +770,20 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             param_id = self.get_param_id(param)
 
             if (current_index >= start_index and current_index < end_index):
-                set_key_value_list(self.param_to_partition_ids[i],
-                                   param_id,
-                                   partition_id)
+                set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
                 increment_value(self.total_grads_in_partition[i], partition_id)
 
                 self.is_grad_computed[i][partition_id][param_id] = False
 
-                self.grad_partition_insertion_offset[i][partition_id][
-                    param_id] = current_index - start_index
+                self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index
                 self.grad_start_offset[i][partition_id][param_id] = 0
 
-            elif start_index > current_index and start_index < (current_index +
-                                                                param_size):
-                assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+            elif start_index > current_index and start_index < (current_index + param_size):
+                assert (first_offset == 0
+                        ), "This can happen either zero or only once as this must be the first tensor in the partition"
                 first_offset = start_index - current_index
 
-                set_key_value_list(self.param_to_partition_ids[i],
-                                   param_id,
-                                   partition_id)
+                set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
                 increment_value(self.total_grads_in_partition[i], partition_id)
 
                 self.is_grad_computed[i][partition_id][param_id] = False
@@ -869,14 +832,12 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     ############### Independent Partition Gradient ########################
     def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
-            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
-                                         param.numel())
+            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
             self.reduce_ipg_grads()
             if self.contiguous_gradients and self.overlap_comm:
                 # Swap ipg_index between 0 and 1
                 self.ipg_index = 1 - self.ipg_index
-            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
-                                         param.numel())
+            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())
 
         param_id = self.get_param_id(param)
         assert self.params_already_reduced[param_id] == False, \
@@ -884,17 +845,14 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             f"The parameter {param_id} has already been reduced. \
             Gradient computed twice for this partition. \
             Multiple gradient reduction is currently not supported"
-        if param.numel() > self.reduce_bucket_size:
-            self.extra_large_param_to_reduce = param
-
-        elif self.contiguous_gradients:
-            # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
-            new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
-                0,
-                self.elements_in_ipg_bucket,
-                param.numel())
-            new_grad_tensor.copy_(param.grad.view(-1))
-            param.grad.data = new_grad_tensor.data.view_as(param.grad)
+        if self.contiguous_gradients:
+            if param.numel() > self.reduce_bucket_size:
+                self.extra_large_param_to_reduce = param
+            else:
+                # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
+                new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
+                new_grad_tensor.copy_(param.grad.view(-1))
+                param.grad.data = new_grad_tensor.data.view_as(param.grad)
 
         self.elements_in_ipg_bucket += param.numel()
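
The control flow above is a classic fill-then-flush bucket: gradients accumulate until the next tensor would overflow `reduce_bucket_size`, at which point a reduction is launched. A standalone sketch of that policy (illustrative, not the patched code):

import torch

def bucketed_reduce(tensors, bucket_size, reduce_fn):
    bucket, used = [], 0
    for t in tensors:
        if used + t.numel() > bucket_size:
            reduce_fn(bucket)       # flush before the bucket overflows
            bucket, used = [], 0
        bucket.append(t)
        used += t.numel()
    if bucket:
        reduce_fn(bucket)           # flush the remainder

bucketed_reduce([torch.ones(3) for _ in range(5)], bucket_size=7,
                reduce_fn=lambda bucket: print(sum(t.numel() for t in bucket)))
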
@@ -969,13 +927,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                 #Otherwise averaging is done at the entire buffer level at the end of the loop
                 # MoE param have different groups
                 if self.ipg_bucket_has_moe_params:
-                    process_group = self.expert_dp_process_group[
-                        param.group_name] if is_moe_param(
-                            param) else self.dp_process_group
+                    process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
+                        param) else self.dp_process_group
                     param.grad.data.div_(dist.get_world_size(group=process_group))
 
                 partition_ids = self.param_to_partition_ids[i][param_id]
-                assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
+                assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
+                            ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
                 partition_size = self.partition_size[i]
                 # Get all partition ids + their offsets
                 partition_ids_w_offsets = []
@@ -1025,10 +983,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     # dist.barrier()
                     #dist.barrier()
                     dst_rank = dist.get_global_rank(real_dp_process_group[i], dst)
-                    async_handle = dist.reduce(grad_slice,
-                                               dst=dst_rank,
-                                               group=real_dp_process_group[i],
-                                               async_op=True)
+                    async_handle = dist.reduce(grad_slice, dst=dst_rank, group=real_dp_process_group[i], async_op=True)
                     async_handles.append(async_handle)
 
             for handle in async_handles:
@@ -1060,10 +1015,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     num_elements = partition_size - current_offset
 
                 self.grad_position[param_id] = [
-                    int(group_id),
-                    int(param_start_offset),
-                    int(current_offset),
-                    int(num_elements)
+                    int(group_id), int(param_start_offset),
+                    int(current_offset), int(num_elements)
                 ]
                 current_offset += num_elements
@@ -1077,10 +1030,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         for lp_param in self.params_in_partition[param_group_index]:
             param_id = self.get_param_id(lp_param)
             [_, _, dest_offset, num_elements] = self.grad_position[param_id]
-            dest_tensor = self.single_partition_of_fp32_groups[
-                param_group_index].grad.view(-1).narrow(0,
-                                                        dest_offset,
-                                                        num_elements)
+            dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow(
+                0, dest_offset, num_elements)
             self.offload_gradient_dict[param_group_index].append(dest_tensor)
 
     def async_accumulate_grad_in_cpu_via_gpu(self, param):
@@ -1089,55 +1040,35 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 
         # copy to a preexisting buffer to avoid memory allocation penalty
-        dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
-            0,
-            0,
-            param.numel())
+        dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel())
 
         #buffer for storing gradients for this parameter in CPU
         def buffer_to_accumulate_to_in_cpu():
             if not self.fp16_master_weights_and_gradients:
-                return get_accelerator().pin_memory(
-                    torch.zeros(param.numel(),
-                                dtype=param.dtype,
-                                device=self.device))
+                return get_accelerator().pin_memory(torch.zeros(param.numel(), dtype=param.dtype, device=self.device))
             else:
-                return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
-                    0,
-                    dest_offset,
-                    num_elements)
+                return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
 
         #accumulate gradients into param.grad or parts of it that belong to this partition
        def accumulate_gradients():
             if not self.fp16_master_weights_and_gradients:
-                dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
-                                  non_blocking=True)
+                dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
                 param.grad.data.view(-1).add_(dest_buffer)
             else:
-                dest_buffer.narrow(0,
-                                   source_offset,
-                                   num_elements).copy_(
-                                       self.accumulated_grads_in_cpu[param_id].view(-1),
-                                       non_blocking=True)
-                param.grad.data.view(-1).narrow(
-                    0,
-                    source_offset,
-                    num_elements).add_(dest_buffer.narrow(0,
-                                                          source_offset,
-                                                          num_elements))
+                dest_buffer.narrow(0, source_offset,
+                                   num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
+                                                       non_blocking=True)
+                param.grad.data.view(-1).narrow(0, source_offset,
+                                                num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))
 
         #move accumulated gradients back to CPU
         def copy_gradients_to_cpu():
             if not self.fp16_master_weights_and_gradients:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(
-                    param.grad.data.view(-1),
-                    non_blocking=True)
+                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), non_blocking=True)
             else:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(
-                    param.grad.data.view(-1).narrow(0,
-                                                    source_offset,
-                                                    num_elements),
-                    non_blocking=True)
+                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1).narrow(
+                    0, source_offset, num_elements),
+                                                                   non_blocking=True)
 
         if param_id not in self.accumulated_grads_in_cpu:
             self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu()
@@ -1177,10 +1108,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
 
-        dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
-            0,
-            dest_offset,
-            num_elements)
+        dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
 
         src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
         if not self.fp16_master_weights_and_gradients:
@@ -1220,16 +1148,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # Sum across all model parallel GPUs.
         total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
-        dist.all_reduce(total_norm_cuda,
-                        op=dist.ReduceOp.SUM,
-                        group=self.dp_process_group)
+        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
 
         self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
         total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
-        if total_norm == float(
-                'inf') or total_norm == -float('inf') or total_norm != total_norm:
+        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
             total_norm = -1
 
         return total_norm
@@ -1258,17 +1183,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     total_size += param_in_partition.numel()
 
             see_memory_usage(f"before copying {total_size} gradients into partition")
-            self.grads_in_partition = torch.empty(
-                int(total_size),
-                dtype=self.dtype,
-                device=get_accelerator().current_device_name())
+            self.grads_in_partition = torch.empty(int(total_size),
                                                  dtype=self.dtype,
                                                  device=get_accelerator().current_device_name())
            see_memory_usage(f"after copying {total_size} gradients into partition")
 
         # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
-        new_grad_tensor = self.grads_in_partition.view(-1).narrow(
-            0,
-            self.grads_in_partition_offset,
-            param.numel())
+        new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel())
         new_grad_tensor.copy_(param.grad.view(-1))
         param.grad.data = new_grad_tensor.data.view_as(param.grad)
         #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
@@ -1279,17 +1200,16 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             if self.extra_large_param_to_reduce is not None:
                 assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
                 _, _, param_id = self.params_in_ipg_bucket[0]
-                assert self.get_param_id(
-                    self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
+                assert self.get_param_id(self.extra_large_param_to_reduce
                                         ) == param_id, "param in ipg bucket does not match extra-large param"
                self.average_tensor(self.extra_large_param_to_reduce.grad.view(-1))
                 self.extra_large_param_to_reduce = None
             else:
                 self.average_tensor(self.ipg_buffer[self.ipg_index])
         else:
-            self.buffered_reduce_fallback(
-                None,
-                self.grads_in_ipg_bucket,
-                elements_per_buffer=self.elements_in_ipg_bucket)
+            self.buffered_reduce_fallback(None,
                                          self.grads_in_ipg_bucket,
                                          elements_per_buffer=self.elements_in_ipg_bucket)
 
         if self.overlap_comm:
             stream = self.reduction_stream
@@ -1324,8 +1244,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     elif self.contiguous_gradients:
                         self.copy_grads_in_partition(param)
                 else:  # zero stage 1 - partition only optimizer state
-                    if self.contiguous_gradients and self.is_param_in_current_partition[
-                            param_id]:
+                    if self.contiguous_gradients and self.is_param_in_current_partition[param_id]:
                         self.copy_grads_in_partition(param)
 
         self.grads_in_ipg_bucket = []
@@ -1339,6 +1258,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
 
     def zero_reduced_gradients(self, partition_id, i):
+
         def are_all_related_partitions_reduced(params_id):
             for partition_id in self.param_to_partition_ids[i][params_id]:
                 if not self.is_partition_reduced[i][partition_id]:
@@ -1358,29 +1278,23 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.sequential_execution(print_func, message)
 
     def get_grads_to_reduce(self, i, partition_id):
+
         def get_reducible_portion(key):
             grad = self.param_dict[key].grad
             total_elements = grad.numel()
             start = self.grad_start_offset[i][partition_id][key]
-            num_elements = min(
-                total_elements - start,
-                self.partition_size[i] -
-                self.grad_partition_insertion_offset[i][partition_id][key])
+            num_elements = min(total_elements - start,
                               self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key])
            if not pg_correctness_test:
                 if num_elements == total_elements:
                     return grad
                 else:
-                    return grad.contiguous().view(-1).narrow(0,
-                                                             int(start),
-                                                             int(num_elements))
+                    return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements))
             else:
                 if num_elements == total_elements:
                     return grad.clone()
                 else:
-                    return grad.clone().contiguous().view(-1).narrow(
-                        0,
-                        int(start),
-                        int(num_elements))
+                    return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements))
 
         grads_to_reduce = []
         for key in self.is_grad_computed[i][partition_id]:
@@ -1456,11 +1370,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
                 buf.copy_(synced)
 
-    def allreduce_no_retain(self,
-                            bucket,
-                            numel_per_bucket=500000000,
-                            rank=None,
-                            log=None):
+    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None):
         small_bucket = []
         numel = 0
         for tensor in bucket:
@@ -1475,18 +1385,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     # allows using reduction of gradients instead of using all_reduce
-    def buffered_reduce_fallback(self,
-                                 rank,
-                                 grads,
-                                 elements_per_buffer=500000000,
-                                 log=None):
+    def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None):
         split_buckets = split_half_float_double(grads)
 
         for i, bucket in enumerate(split_buckets):
-            self.allreduce_no_retain(bucket,
-                                     numel_per_bucket=elements_per_buffer,
-                                     rank=rank,
-                                     log=log)
+            self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log)
 
     #############################################################################
     #############################################################################
@@ -1531,11 +1434,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             if (current_index >= start_index and current_index < end_index):
                 params_in_partition.append(tensor)
 
-            elif start_index > current_index and start_index < (current_index +
-                                                                tensor_size):
+            elif start_index > current_index and start_index < (current_index + tensor_size):
                 params_in_partition.append(tensor)
 
-                assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
+                assert (first_offset == 0
+                        ), "This can happen either zero or only once as this must be the first tensor in the partition"
                 first_offset = start_index - current_index
 
             else:
@@ -1589,9 +1492,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         if norm_type == inf:
             total_norm = max(g.data.abs().max() for g in gradients)
             total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
-            dist.all_reduce(total_norm_cuda,
-                            op=dist.ReduceOp.MAX,
-                            group=self.dp_process_group)
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)
 
             # Take max across all GPUs.
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
@@ -1609,16 +1510,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                     total_norm += param_norm.item()**2
 
             # Sum across all model parallel GPUs.
             total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
-            dist.all_reduce(total_norm_cuda,
-                            op=dist.ReduceOp.SUM,
-                            group=self.dp_process_group)
+            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)
 
             self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
 
             total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
-        if total_norm == float(
-                'inf') or total_norm == -float('inf') or total_norm != total_norm:
+        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
             total_norm = -1
 
         return total_norm
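
The final check relies on NaN never comparing equal to itself, so `total_norm != total_norm` detects NaN without an explicit isnan call, and -1 is the sentinel value callers test for. In isolation:

def sanitize_norm(total_norm):
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        return -1  # sentinel for an overflowed/invalid norm
    return total_norm

assert sanitize_norm(float('nan')) == -1
assert sanitize_norm(float('inf')) == -1
assert sanitize_norm(3.0) == 3.0
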
@@ -1626,13 +1524,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     # creates a flat fused tensor from the tensor list starting at the first_offset
     # in the first tensor of the list. If there are not enough elements in the tensor
     # list then the flat tensor will be padded with zeros
-    def get_flat_partition(self,
-                           tensor_list,
-                           first_offset,
-                           partition_size,
-                           dtype,
-                           device,
-                           return_tensor_list=False):
+    def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
         flat_tensor_list = []
         current_size = 0
 
         for i, tensor in enumerate(tensor_list):
@@ -1655,10 +1547,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             # we need a narrow view of the tensor based on the tensor offset and number of elements that
             # we need from this tensor
             if tensor_offset > 0 or num_elements < tensor.numel():
-                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
-                    0,
-                    int(tensor_offset),
-                    int(num_elements)))
+                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements)))
             else:
                 flat_tensor_list.append(tensor)
@@ -1666,10 +1555,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         # this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
         if current_size < partition_size:
-            flat_tensor_list.append(
-                torch.zeros(int(partition_size - current_size),
-                            dtype=dtype,
-                            device=device))
+            flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))
 
         if return_tensor_list:
             return flat_tensor_list
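
A sketch of what `get_flat_partition` produces for a partition that starts mid-tensor and overruns the available elements (hand-built pieces; the real method also flattens the list unless return_tensor_list is set):

import torch

# partition of size 5 starting at offset 1 of a 4-element tensor:
# take elements [1:4] of the first tensor, all of the second (numel 1),
# then zero-pad 1 element to reach the partition size
tensors = [torch.arange(4.), torch.tensor([9.])]
pieces = [tensors[0].view(-1).narrow(0, 1, 3), tensors[1], torch.zeros(1)]
flat = torch.cat(pieces)
assert flat.tolist() == [1.0, 2.0, 3.0, 9.0, 0.0]
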
@@ -1715,9 +1601,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def override_loss_scale(self, loss_scale):
         if loss_scale != self.external_loss_scale:
-            logger.info(
-                f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}'
-            )
+            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
         self.custom_loss_scaler = True
         self.external_loss_scale = loss_scale
@@ -1727,14 +1611,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         for i, group in enumerate(self.bit16_groups):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
-                norm_groups.append(
-                    self.complete_grad_norm_calculation_for_cpu_offload(
-                        self.params_in_partition[i]))
+                norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
             else:
-                norm_groups.append(
-                    self.get_grad_norm_direct(self.averaged_gradients[i],
-                                              self.params_in_partition[i]))
+                norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))
 
         if self.has_moe_layers:
             self._average_expert_grad_norms(norm_groups)
@@ -1745,18 +1625,18 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
         partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
-        return [
-            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
-        ]
+        return [bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]]
 
     def _optimizer_step(self, group_no):
         original_param_groups = self.optimizer.param_groups
         self.optimizer.param_groups = [original_param_groups[group_no]]
-        from deepspeed.ops.adam import DeepSpeedCPUAdam
-        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
-        else:
-            self.optimizer.step()
+        # Disabling this as the C++ side copy & synchronize is not working correctly
+        #from deepspeed.ops.adam import DeepSpeedCPUAdam
+        #if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
+        #    self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
+        #else:
+        #    self.optimizer.step()
+        self.optimizer.step()
         self.optimizer.param_groups = original_param_groups
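
The `param_groups` swap above steps one group at a time through the wrapped optimizer: temporarily expose a single group, step, then restore. Demonstrated with a plain SGD outside DeepSpeed (illustrative names):

import torch

group_a = [torch.nn.Parameter(torch.zeros(2))]
group_b = [torch.nn.Parameter(torch.zeros(2))]
opt = torch.optim.SGD([{'params': group_a}, {'params': group_b}], lr=0.1)

original_groups = opt.param_groups
opt.param_groups = [original_groups[0]]  # only group 0 is visible to step()
opt.step()
opt.param_groups = original_groups       # restore the full set
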
     def step(self, closure=None):
@@ -1777,12 +1657,6 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
         if self.overflow:
-            if dist.get_rank() == 0:
-                overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
-                if self.dtype == torch.half:
-                    overflow_msg += f" Attempted loss scale: {prev_scale}, reducing to {self.loss_scale}"
-                logger.info(overflow_msg)
-
             see_memory_usage('After overflow before clearing gradients')
             self.zero_grad(set_to_none=True)
             if self.cpu_offload:
...@@ -1797,29 +1671,34 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1797,29 +1671,34 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
return return
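For context, the skipped step above leans on the dynamic loss scaler: _update_scale() shrinks the scale on overflow and grows it again after a run of clean steps. A sketch of that policy under assumed names (init_scale, scale_factor, scale_window are illustrative; the real scaler's behavior is configuration-dependent):

    class DynamicLossScaleSketch:
        def __init__(self, init_scale=2.0**16, scale_factor=2.0, scale_window=1000):
            self.cur_scale = init_scale
            self.scale_factor = scale_factor
            self.scale_window = scale_window
            self.good_steps = 0

        def update_scale(self, overflow):
            if overflow:
                # Back off and restart the window of clean steps.
                self.cur_scale = max(self.cur_scale / self.scale_factor, 1.0)
                self.good_steps = 0
            else:
                self.good_steps += 1
                if self.good_steps % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor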
# Step 1:- Calculate gradient norm using fp-16 grads # Step 1:- Calculate gradient norm using fp-16 grads
see_memory_usage('Before norm calculation') if self.dtype == torch.float16:
scaled_global_grad_norm = self.scaled_global_norm() see_memory_usage('Before norm calculation')
self._global_grad_norm = scaled_global_grad_norm / prev_scale scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / prev_scale
see_memory_usage('After norm before optimizer')
see_memory_usage('After norm before optimizer')
# Step 2:- run optimizer and unscaling simultaneously # Step 2:- run optimizer and unscaling simultaneously
for i, group in enumerate(self.bit16_groups): for i, group in enumerate(self.bit16_groups):
self.start_timers([OPTIMIZER_GRADIENTS]) self.start_timers([OPTIMIZER_GRADIENTS])
partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload: if self.cpu_offload:
single_grad_partition = self.single_partition_of_fp32_groups[i].grad single_grad_partition = self.single_partition_of_fp32_groups[i].grad
self.unscale_and_clip_grads([single_grad_partition], if self.dtype == torch.float16:
scaled_global_grad_norm) self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.stop_timers([OPTIMIZER_GRADIENTS]) self.stop_timers([OPTIMIZER_GRADIENTS])
self.start_timers([OPTIMIZER_STEP]) self.start_timers([OPTIMIZER_STEP])
self._optimizer_step(i) self._optimizer_step(i)
from deepspeed.ops.adam import DeepSpeedCPUAdam # Disabled, this is not currently working
if not (type(self.optimizer) == DeepSpeedCPUAdam #from deepspeed.ops.adam import DeepSpeedCPUAdam
and self.dtype == torch.half): #if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
bit16_partitions = self.parallel_partitioned_bit16_groups[i] # bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i] # fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data) # bit16_partitions[partition_id].data.copy_(fp32_partition.data)
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
self.stop_timers([OPTIMIZER_STEP]) self.stop_timers([OPTIMIZER_STEP])
else: else:
...@@ -1828,12 +1707,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1828,12 +1707,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
# create a flat gradients for parameters updated by this process # create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size( if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
group=self.real_dp_process_group[i]) - 1:
single_grad_partition = self.flatten_dense_tensors_aligned( single_grad_partition = self.flatten_dense_tensors_aligned(
self.averaged_gradients[i], self.averaged_gradients[i],
int(self.partition_size[i])).to( int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype)
self.single_partition_of_fp32_groups[i].dtype)
else: else:
single_grad_partition = self.flatten(self.averaged_gradients[i]).to( single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype) self.single_partition_of_fp32_groups[i].dtype)
...@@ -1847,8 +1724,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1847,8 +1724,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.averaged_gradients[i] = None self.averaged_gradients[i] = None
self.unscale_and_clip_grads([single_grad_partition], if self.dtype == torch.float16:
scaled_global_grad_norm) self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)
self.stop_timers([OPTIMIZER_GRADIENTS]) self.stop_timers([OPTIMIZER_GRADIENTS])
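unscale_and_clip_grads() can fold loss-scale removal and norm clipping into a single multiply. A sketch of that combination, assuming total_norm is the norm of the still-scaled gradients and clip_grad is the clipping threshold:

    import torch

    def unscale_and_clip_sketch(grads, total_norm, loss_scale, clip_grad):
        combined_scale = loss_scale
        if clip_grad > 0.0:
            # Norm of the unscaled gradients, relative to the clip threshold.
            clip = (total_norm / loss_scale) / clip_grad
            if clip > 1.0:
                combined_scale = clip * loss_scale
        for g in grads:
            g.data.mul_(1.0 / combined_scale)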
# Step 3:- run the optimizer if no offloading # Step 3:- run the optimizer if no offloading
...@@ -1869,11 +1747,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1869,11 +1747,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self.start_timers([OPTIMIZER_ALLGATHER]) self.start_timers([OPTIMIZER_ALLGATHER])
# Gather the updated weights from everyone. # Gather the updated weights from everyone.
# Then all partitions of the model parameters are updated and ready for next round forward. # Then all partitions of the model parameters are updated and ready for next round forward.
all_gather_dp_groups( all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups,
partitioned_param_groups=self.parallel_partitioned_bit16_groups, dp_process_group=self.real_dp_process_group,
dp_process_group=self.real_dp_process_group, start_alignment_factor=self.nccl_start_alignment_factor,
start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size)
allgather_bucket_size=self.allgather_bucket_size)
self.stop_timers([OPTIMIZER_ALLGATHER]) self.stop_timers([OPTIMIZER_ALLGATHER])
...@@ -1888,24 +1765,23 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1888,24 +1765,23 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
@torch.no_grad() @torch.no_grad()
def update_lp_params(self): def update_lp_params(self):
for i, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): for i, (bit16_partitions, fp32_partition) in enumerate(
zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bit16_partitions[partition_id].data.copy_(fp32_partition.data) bit16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0: # if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True) # print_rank_0(f'{fp32_partition[:10]=}', force=True)
all_gather_dp_groups( all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups,
partitioned_param_groups=self.parallel_partitioned_bit16_groups, dp_process_group=self.real_dp_process_group,
dp_process_group=self.real_dp_process_group, start_alignment_factor=self.nccl_start_alignment_factor,
start_alignment_factor=self.nccl_start_alignment_factor, allgather_bucket_size=self.allgather_bucket_size)
allgather_bucket_size=self.allgather_bucket_size)
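The bucketed all_gather_dp_groups() above plays the role of a plain all-gather of equal-sized shards; an illustrative unbucketed equivalent (assumes an initialized process group):

    import torch
    import torch.distributed as dist

    def allgather_partitions_sketch(my_shard, world_size, group=None):
        # Every rank contributes its updated shard and receives all others.
        shards = [torch.empty_like(my_shard) for _ in range(world_size)]
        dist.all_gather(shards, my_shard, group=group)
        return torch.cat(shards)  # flat, fully updated parameter buffer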
def _average_expert_grad_norms(self, norm_groups): def _average_expert_grad_norms(self, norm_groups):
for i, norm in enumerate(norm_groups): for i, norm in enumerate(norm_groups):
if self.is_moe_param_group[i]: if self.is_moe_param_group[i]:
scaled_norm = norm * 1.0 / float( scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.real_dp_process_group[i]))
dist.get_world_size(group=self.real_dp_process_group[i]))
scaled_norm_tensor = torch.tensor(scaled_norm, scaled_norm_tensor = torch.tensor(scaled_norm,
device=get_accelerator().device_name(), device=get_accelerator().device_name(),
dtype=torch.float) dtype=torch.float)
...@@ -1949,14 +1825,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -1949,14 +1825,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
def has_overflow(self, partition_gradients=True): def has_overflow(self, partition_gradients=True):
if partition_gradients: if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial( overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
)
overflow_gpu = get_accelerator().ByteTensor([overflow]) overflow_gpu = get_accelerator().ByteTensor([overflow])
'''This will capture overflow across all data parallel and expert parallel processes '''This will capture overflow across all data parallel and expert parallel processes
Since expert parallel processes are a subset of data parallel processes''' Since expert parallel processes are a subset of data parallel processes'''
dist.all_reduce(overflow_gpu, dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
op=dist.ReduceOp.MAX,
group=self.dp_process_group)
else: else:
params = [] params = []
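The overflow check above is a classic flag reduction: each rank contributes 0 or 1 and a MAX all-reduce makes any rank's overflow visible everywhere. A sketch (assumes an initialized process group; device placement is elided):

    import torch
    import torch.distributed as dist

    def any_rank_overflowed(local_overflow, group=None):
        flag = torch.tensor([1 if local_overflow else 0], dtype=torch.uint8)
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
        return bool(flag.item())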
...@@ -2093,9 +1966,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2093,9 +1966,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
optimizer_groups_state = [] optimizer_groups_state = []
for i, group in enumerate(self.optimizer.param_groups): for i, group in enumerate(self.optimizer.param_groups):
p = group['params'][0] p = group['params'][0]
lean_optimizer_state = self._get_state_without_padding( lean_optimizer_state = self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i])
self.optimizer.state[p],
self.groups_padding[i])
optimizer_groups_state.append(lean_optimizer_state) optimizer_groups_state.append(lean_optimizer_state)
return optimizer_groups_state return optimizer_groups_state
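_get_state_without_padding() trims the DP-alignment padding from each tensor-valued optimizer state. A sketch of the idea, assuming every tensor state carries the padded partition length:

    import torch

    def state_without_padding_sketch(state, padding):
        lean = {}
        for key, value in state.items():
            if torch.is_tensor(value):
                lean[key] = value[:value.numel() - padding]
            else:
                lean[key] = value  # e.g. an integer step counter
        return lean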
...@@ -2123,8 +1994,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2123,8 +1994,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict() state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()
# Remove paddings for DP alignment to enable loading for other alignment values # Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding( fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups)
self.single_partition_of_fp32_groups)
state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding
state_dict[ state_dict[
...@@ -2146,17 +2016,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2146,17 +2016,13 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
for i in range(len(self.single_partition_of_fp32_groups)): for i in range(len(self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_id = dist.get_rank(group=self.real_dp_process_group[i])
merged_partitions = [ merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict]
sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict
]
if self.is_moe_group(self.optimizer.param_groups[i]): if self.is_moe_group(self.optimizer.param_groups[i]):
ranks = self.get_ep_ranks( ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
group_name=self.optimizer.param_groups[i]['name'])
merged_partitions = [merged_partitions[i] for i in ranks] merged_partitions = [merged_partitions[i] for i in ranks]
flat_merged_partitions = self.flatten_dense_tensors_aligned( flat_merged_partitions = self.flatten_dense_tensors_aligned(
merged_partitions, merged_partitions,
self.nccl_start_alignment_factor * self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]))
dist.get_world_size(group=self.real_dp_process_group[i]))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)
merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id]) merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
...@@ -2165,7 +2031,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2165,7 +2031,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
# Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights
def _restore_from_bit16_weights(self): def _restore_from_bit16_weights(self):
for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): for group_id, (bit16_partitions, fp32_partition) in enumerate(
zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
fp32_partition.data.copy_(bit16_partitions[partition_id].data) fp32_partition.data.copy_(bit16_partitions[partition_id].data)
...@@ -2178,11 +2045,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2178,11 +2045,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
alignment = dist.get_world_size(group=self.real_dp_process_group[group_id]) alignment = dist.get_world_size(group=self.real_dp_process_group[group_id])
if torch.is_tensor(all_partition_states[0]): if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = self.flatten_dense_tensors_aligned( flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment)
all_partition_states, dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id)
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions,
group_id)
return dp_partitions[partition_id] return dp_partitions[partition_id]
else: else:
# Assume non-tensor states are not partitioned and equal across ranks, so return first one # Assume non-tensor states are not partitioned and equal across ranks, so return first one
...@@ -2217,25 +2081,15 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2217,25 +2081,15 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
base_optimizer_group_states = [] base_optimizer_group_states = []
for i in range(len(self.optimizer.param_groups)): for i in range(len(self.optimizer.param_groups)):
partition_states = {} partition_states = {}
all_partition_group_states = [ all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict]
sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict
]
if self.is_moe_group(self.optimizer.param_groups[i]): if self.is_moe_group(self.optimizer.param_groups[i]):
ranks = self.get_ep_ranks( ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
group_name=self.optimizer.param_groups[i]['name']) all_partition_group_states = [all_partition_group_states[i] for i in ranks]
all_partition_group_states = [
all_partition_group_states[i] for i in ranks
]
for key in all_partition_group_states[0].keys(): for key in all_partition_group_states[0].keys():
all_partition_states = [ all_partition_states = [all_states[key] for all_states in all_partition_group_states]
all_states[key] for all_states in all_partition_group_states partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
]
partition_states[key] = self._partition_base_optimizer_state(
key,
all_partition_states,
i)
base_optimizer_group_states.append(partition_states) base_optimizer_group_states.append(partition_states)
self._restore_base_optimizer_state(base_optimizer_group_states) self._restore_base_optimizer_state(base_optimizer_group_states)
...@@ -2246,18 +2100,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2246,18 +2100,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
load_from_fp32_weights=False, load_from_fp32_weights=False,
checkpoint_folder=None): checkpoint_folder=None):
if checkpoint_folder: if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
load_optimizer_states,
load_from_fp32_weights)
else: else:
self._load_legacy_checkpoint(state_dict_list, self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
load_optimizer_states,
load_from_fp32_weights) def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
def _load_universal_checkpoint(self,
checkpoint_folder,
load_optimizer_states,
load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder) self._load_hp_checkpoint_state(checkpoint_folder)
@property @property
...@@ -2274,16 +2121,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2274,16 +2121,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
for lp in self.bit16_groups[i]: for lp in self.bit16_groups[i]:
if lp._hp_mapping is not None: if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}") #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state( lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
os.path.join(checkpoint_dir, tp_world_size)
self.param_names[lp]),
tp_rank, def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
tp_world_size)
def _load_legacy_checkpoint(self,
state_dict_list,
load_optimizer_states=True,
load_from_fp32_weights=False):
r"""Loading ZeRO checkpoint r"""Loading ZeRO checkpoint
Arguments: Arguments:
...@@ -2314,8 +2155,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2314,8 +2155,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
dp_rank = dist.get_rank(group=self.dp_process_group) dp_rank = dist.get_rank(group=self.dp_process_group)
current_rank_sd = state_dict_list[dp_rank] current_rank_sd = state_dict_list[dp_rank]
self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler) self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler)
self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale)
self.dynamic_loss_scale)
self.overflow = current_rank_sd.get('overflow', self.overflow) self.overflow = current_rank_sd.get('overflow', self.overflow)
self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad) self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)
...@@ -2353,8 +2193,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2353,8 +2193,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self._restore_elastic_base_optimizer_state(state_dict_list) self._restore_elastic_base_optimizer_state(state_dict_list)
else: else:
# loading an elastic checkpoint into rigid exec # loading an elastic checkpoint into rigid exec
self._restore_base_optimizer_state( self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])
current_rank_sd[BASE_OPTIMIZER_STATE])
# At this point, the optimizer's references to the model's fp32 parameters are up to date. # At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date. # The optimizer's hyperparameters and internal buffers are also up to date.
...@@ -2377,7 +2216,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer): ...@@ -2377,7 +2216,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self._restore_from_elastic_fp32_weights(state_dict_list) self._restore_from_elastic_fp32_weights(state_dict_list)
else: else:
# For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient. # For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient.
for current, saved in zip(self.single_partition_of_fp32_groups, current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]): for current, saved in zip(self.single_partition_of_fp32_groups,
current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
src_tensor = _get_padded_tensor(saved, current.numel()) src_tensor = _get_padded_tensor(saved, current.numel())
current.data.copy_(src_tensor.data) current.data.copy_(src_tensor.data)
else: else:
...@@ -2397,9 +2237,7 @@ def _handle_overflow(cpu_sum, x, i): ...@@ -2397,9 +2237,7 @@ def _handle_overflow(cpu_sum, x, i):
if not math.isfinite(float(v)): if not math.isfinite(float(v)):
t_i = v_i t_i = v_i
break break
logger.info( logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)
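The scan in _handle_overflow reduces to finding the first non-finite element, which keeps the log actionable:

    import math
    import torch

    def first_nonfinite_index(x):
        for i, v in enumerate(x.flatten().tolist()):
            if not math.isfinite(float(v)):
                return i
        return None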
def estimate_zero2_model_states_mem_needs(total_params, def estimate_zero2_model_states_mem_needs(total_params,
...@@ -2422,9 +2260,7 @@ def estimate_zero2_model_states_mem_needs(total_params, ...@@ -2422,9 +2260,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
def model_to_params(model): def model_to_params(model):
# shared params calculated only once # shared params calculated only once
total_params = sum( total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
dict((p.data_ptr(),
p.numel()) for p in model.parameters()).values())
return total_params return total_params
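Keying by data_ptr() in model_to_params counts storage-sharing (tied) parameters once. A tiny illustration:

    import torch

    w = torch.nn.Parameter(torch.zeros(10, 4))
    params = [w, w]  # the same tensor reachable twice, as with tied embeddings
    total = sum(dict((p.data_ptr(), p.numel()) for p in params).values())
    assert total == 40  # counted once, not 80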
...@@ -2452,11 +2288,10 @@ def estimate_zero2_model_states_mem_needs_all_live(model, ...@@ -2452,11 +2288,10 @@ def estimate_zero2_model_states_mem_needs_all_live(model,
total_params = model_to_params(model) total_params = model_to_params(model)
estimate_zero2_model_states_mem_needs_all_cold( estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params,
total_params=total_params, num_gpus_per_node=num_gpus_per_node,
num_gpus_per_node=num_gpus_per_node, num_nodes=num_nodes,
num_nodes=num_nodes, additional_buffer_factor=additional_buffer_factor)
additional_buffer_factor=additional_buffer_factor)
def estimate_zero2_model_states_mem_needs_all_cold(total_params, def estimate_zero2_model_states_mem_needs_all_cold(total_params,
...@@ -2480,6 +2315,7 @@ def estimate_zero2_model_states_mem_needs_all_cold(total_params, ...@@ -2480,6 +2315,7 @@ def estimate_zero2_model_states_mem_needs_all_cold(total_params,
- ``additional_buffer_factor``: estimation factor (defaults to 1.5): - ``additional_buffer_factor``: estimation factor (defaults to 1.5):
""" """
def format_options(cpu_offload): def format_options(cpu_offload):
enabled = [] enabled = []
device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none" device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none"
...@@ -2488,19 +2324,16 @@ def estimate_zero2_model_states_mem_needs_all_cold(total_params, ...@@ -2488,19 +2324,16 @@ def estimate_zero2_model_states_mem_needs_all_cold(total_params,
nodes_str = "nodes" if num_nodes > 1 else "node" nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print( print("Estimated memory needed for params, optim states and gradients for a:\n"
"Estimated memory needed for params, optim states and gradients for a:\n" f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" f"SW: Model with {int(total_params/1e6)}M total params.")
f"SW: Model with {int(total_params/1e6)}M total params.")
print(" per CPU | per GPU | Options") print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]: for cpu_offload in [True, False]:
cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs( cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params,
total_params=total_params, num_gpus_per_node=num_gpus_per_node,
num_gpus_per_node=num_gpus_per_node, num_nodes=num_nodes,
num_nodes=num_nodes, cpu_offload=cpu_offload,
cpu_offload=cpu_offload, additional_buffer_factor=additional_buffer_factor)
additional_buffer_factor=additional_buffer_factor
)
options_str = format_options(cpu_offload=cpu_offload) options_str = format_options(cpu_offload=cpu_offload)
print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
from deepspeed.runtime.zero.contiguous_memory_allocator import ContiguousMemoryAllocator from deepspeed.runtime.zero.contiguous_memory_allocator import ContiguousMemoryAllocator
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
import deepspeed import deepspeed
...@@ -27,6 +30,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals ...@@ -27,6 +30,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals
class TiledLinear(torch.nn.Module): class TiledLinear(torch.nn.Module):
def __init__(self, def __init__(self,
in_features, in_features,
out_features, out_features,
...@@ -114,10 +118,7 @@ class TiledLinear(torch.nn.Module): ...@@ -114,10 +118,7 @@ class TiledLinear(torch.nn.Module):
local_bias = bias if in_id == (in_splits - 1) else False local_bias = bias if in_id == (in_splits - 1) else False
local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id] local_in_dim = self.in_parts[in_id + 1] - self.in_parts[in_id]
local = linear_cls(local_in_dim, local = linear_cls(local_in_dim, local_out_dim, bias=local_bias, **kwargs)
local_out_dim,
bias=local_bias,
**kwargs)
self.linears[out_id].append(local) self.linears[out_id].append(local)
# Optionally initialize with a known tensor # Optionally initialize with a known tensor
...@@ -127,13 +128,12 @@ class TiledLinear(torch.nn.Module): ...@@ -127,13 +128,12 @@ class TiledLinear(torch.nn.Module):
def forward(self, input_): def forward(self, input_):
if self.in_splits > 1 and not self.input_is_already_split: if self.in_splits > 1 and not self.input_is_already_split:
input_parts = partition(input_.shape[-1], self.in_splits) input_parts = partition(input_.shape[-1], self.in_splits)
split_sizes = [ split_sizes = [input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)]
input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)
]
inputs = self._split_global_input(input_, split_sizes) inputs = self._split_global_input(input_, split_sizes)
elif self.in_splits > 1: elif self.in_splits > 1:
inputs = input_ inputs = input_
assert len(inputs) == self.in_splits, f"Col splits {self.in_splits} does not match input splits {len(inputs)}" assert len(
inputs) == self.in_splits, f"Col splits {self.in_splits} does not match input splits {len(inputs)}"
else: else:
# no splits # no splits
inputs = [input_] inputs = [input_]
...@@ -260,6 +260,7 @@ class TiledLinearReturnBias(TiledLinear): ...@@ -260,6 +260,7 @@ class TiledLinearReturnBias(TiledLinear):
"""Wrapper for a Linear class that returns its own bias parameter, such as """Wrapper for a Linear class that returns its own bias parameter, such as
used by Megatron-LM. used by Megatron-LM.
""" """
def _reduce_local_output(self, in_id, out_id, current_out, new_out): def _reduce_local_output(self, in_id, out_id, current_out, new_out):
"""Reduces output tensors, but not the returned bias. """ """Reduces output tensors, but not the returned bias. """
if current_out is not None: if current_out is not None:
...@@ -273,10 +274,7 @@ class TiledLinearReturnBias(TiledLinear): ...@@ -273,10 +274,7 @@ class TiledLinearReturnBias(TiledLinear):
tensor, bias = new_out tensor, bias = new_out
assert tensor is not None assert tensor is not None
tensor = super()._reduce_local_output(in_id=in_id, tensor = super()._reduce_local_output(in_id=in_id, out_id=out_id, current_out=old_tensor, new_out=tensor)
out_id=out_id,
current_out=old_tensor,
new_out=tensor)
if bias is None: if bias is None:
bias = old_bias bias = old_bias
......
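The tiling above splits a Linear along in_features and reduces the partial outputs, adding the bias only once. A self-contained sketch of that math (not the TiledLinear API):

    import torch

    def tiled_linear_sketch(x, weight, bias, in_splits):
        # Split the input and the weight's in_features dimension the same way,
        # then sum the partial products; the bias is added once at the end.
        x_chunks = torch.chunk(x, in_splits, dim=-1)
        w_chunks = torch.chunk(weight, in_splits, dim=1)
        out = None
        for xc, wc in zip(x_chunks, w_chunks):
            partial = torch.nn.functional.linear(xc, wc)
            out = partial if out is None else out + partial
        return out if bias is None else out + bias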
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os import os
from typing import List from typing import List
...@@ -7,6 +10,7 @@ import torch ...@@ -7,6 +10,7 @@ import torch
from deepspeed import comm as dist from deepspeed import comm as dist
from deepspeed.utils import logger from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator from deepspeed.accelerator import get_accelerator
...@@ -15,9 +19,7 @@ from deepspeed.accelerator import get_accelerator ...@@ -15,9 +19,7 @@ from deepspeed.accelerator import get_accelerator
def _initialize_parameter_parallel_groups(parameter_parallel_size=None): def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
data_parallel_size = int(dist.get_world_size()) data_parallel_size = int(dist.get_world_size())
parameter_parallel_size = parameter_parallel_size or data_parallel_size parameter_parallel_size = parameter_parallel_size or data_parallel_size
logger.info("data_parallel_size: %s, parameter_parallel_size: %s", logger.info("data_parallel_size: %s, parameter_parallel_size: %s", data_parallel_size, parameter_parallel_size)
data_parallel_size,
parameter_parallel_size)
assert data_parallel_size % parameter_parallel_size == 0, \ assert data_parallel_size % parameter_parallel_size == 0, \
'world size should be divisible by parameter parallel size' 'world size should be divisible by parameter parallel size'
rank = dist.get_rank() rank = dist.get_rank()
...@@ -35,10 +37,7 @@ class ZeRORuntimeException(Exception): ...@@ -35,10 +37,7 @@ class ZeRORuntimeException(Exception):
ZERO_SUPPORTED_OPTIMIZERS = [ ZERO_SUPPORTED_OPTIMIZERS = [
torch.optim.Adam, torch.optim.Adam, torch.optim.AdamW, FusedAdam, DeepSpeedCPUAdam, torch.optim.Adagrad, DeepSpeedCPUAdagrad
torch.optim.AdamW,
FusedAdam,
DeepSpeedCPUAdam
] ]
# Add apex FusedAdam to supported list if apex is installed # Add apex FusedAdam to supported list if apex is installed
...@@ -52,9 +51,7 @@ except ImportError: ...@@ -52,9 +51,7 @@ except ImportError:
def is_zero_supported_optimizer(optimizer): def is_zero_supported_optimizer(optimizer):
if dist.get_rank() == 0: if dist.get_rank() == 0:
logger.info( logger.info(f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}')
f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
)
return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .logging import logger, log_dist from .logging import logger, log_dist
from .comms_logging import get_caller_func from .comms_logging import get_caller_func
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math import math
from deepspeed.utils import log_dist from deepspeed.utils import log_dist
...@@ -31,7 +34,7 @@ def calc_bw_log(comm_op, size, duration): ...@@ -31,7 +34,7 @@ def calc_bw_log(comm_op, size, duration):
if comm_op == "all_to_all_single": if comm_op == "all_to_all_single":
tput = (size / duration) tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n) busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "all_gather" or comm_op == "all_gather_base" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_base": elif comm_op == "all_gather" or comm_op == "all_gather_into_tensor" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_tensor":
size *= n size *= n
tput = (size / duration) tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n) busbw = (size / duration) * ((n - 1) / n)
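The (n - 1) / n factor reflects that each rank already holds its own 1/n shard, so only that fraction of the gathered volume crosses the interconnect. A worked example for all_gather, assuming size is the per-rank shard in bytes and duration is in seconds:

    n = 8                  # world size
    size = 1 * 2**30       # 1 GiB shard per rank
    duration = 0.25        # seconds
    gathered = size * n    # the logically moved volume, as in the code above
    algbw = gathered / duration    # 32 GiB/s algorithmic bandwidth
    busbw = algbw * (n - 1) / n    # 28 GiB/s actually on the wire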
...@@ -56,6 +59,7 @@ def calc_bw_log(comm_op, size, duration): ...@@ -56,6 +59,7 @@ def calc_bw_log(comm_op, size, duration):
class CommsLogger: class CommsLogger:
def __init__(self): def __init__(self):
from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT
self.comms_dict = {} self.comms_dict = {}
...@@ -111,8 +115,7 @@ class CommsLogger: ...@@ -111,8 +115,7 @@ class CommsLogger:
# TODO: Add to tensorboard # TODO: Add to tensorboard
if self.verbose: if self.verbose:
n = dist.get_world_size() n = dist.get_world_size()
log_str = f"rank={dist.get_rank()} | comm op: " + record_name + " | time (ms): {:.2f}".format( log_str = f"rank={dist.get_rank()} | comm op: " + record_name + " | time (ms): {:.2f}".format(latency)
latency)
log_str += " | msg size: " + convert_size(msg_size) log_str += " | msg size: " + convert_size(msg_size)
log_str += " | algbw (Gbps): {:.2f} ".format(algbw) log_str += " | algbw (Gbps): {:.2f} ".format(algbw)
log_str += " | busbw (Gbps): {:.2f} ".format(busbw) log_str += " | busbw (Gbps): {:.2f} ".format(busbw)
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
""" debug utils """ # SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# For lazy import with printflock() # For lazy import with printflock()
fcntl = None fcntl = None
...@@ -128,6 +130,7 @@ def log_rank_file(rank, *msgs): ...@@ -128,6 +130,7 @@ def log_rank_file(rank, *msgs):
def print_backward_tensors(tensor): def print_backward_tensors(tensor):
def _print_bwd_tensors(grad_fn): def _print_bwd_tensors(grad_fn):
print(f"Backward tensors in {grad_fn}") print(f"Backward tensors in {grad_fn}")
for funcs in grad_fn.next_functions: for funcs in grad_fn.next_functions:
...@@ -135,9 +138,7 @@ def print_backward_tensors(tensor): ...@@ -135,9 +138,7 @@ def print_backward_tensors(tensor):
try: try:
tensor = getattr(funcs[0], 'variable') tensor = getattr(funcs[0], 'variable')
print(funcs[0]) print(funcs[0])
print( print(f"Tensor - id: {id(tensor)}, shape: {tensor.shape}, data: {tensor}, grad: {tensor.grad}")
f"Tensor - id: {id(tensor)}, shape: {tensor.shape}, data: {tensor}, grad: {tensor.grad}"
)
except AttributeError as e: except AttributeError as e:
_print_bwd_tensors(funcs[0]) _print_bwd_tensors(funcs[0])
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import os
import torch
from datetime import timedelta
from .logging import logger
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
def init_distributed(dist_backend="nccl",
auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True,
timeout=default_pg_timeout,
init_method=None):
"""Initialize torch.distributed backend, potentially performing MPI discovery if needed
Arguments:
dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
auto_mpi_discovery: Optional (bool). If distributed environment variables are not set, attempt to discover them from MPI
distributed_port: Optional (int). torch distributed backend port
verbose: Optional (bool). verbose logging
timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
init_method: Optional (string). URL specifying how to initialize the torch.distributed process group. Default is "env://" if no init_method or store is specified.
"""
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
logger.info(
"Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI environment..."
)
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
elif in_aws_sm():
patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
else:
mpi_discovery(distributed_port=distributed_port, verbose=verbose)
if not torch.distributed.is_initialized():
if verbose and int(os.getenv('RANK', '0')) == 0:
logger.info(
"Initializing torch distributed with backend: {}".format(dist_backend))
assert isinstance(timeout, timedelta)
torch.distributed.init_process_group(backend=dist_backend,
timeout=timeout,
init_method=init_method)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
"""
Discover the MPI environment via mpi4py and map it to the relevant torch.distributed state
"""
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(distributed_port)
if verbose:
logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if torch.distributed.is_initialized():
assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
rank, torch.distributed.get_rank())
assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, torch.distributed.get_world_size())
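The local-rank rule in mpi_discovery counts how many earlier global ranks share this host. Illustrated with made-up hostnames:

    all_procs = ["nodeA", "nodeA", "nodeB", "nodeA"]  # rank -> hostname
    rank = 3
    proc_name = all_procs[rank]
    local_rank = sum(i == proc_name for i in all_procs[:rank])
    assert local_rank == 2  # two earlier ranks also run on nodeA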
def in_aml():
# Are we running inside an Azure Machine Learning (AML) environment?
return 'AZUREML_EXPERIMENT_ID' in os.environ
def in_aws_sm():
# Are we running inside an AWS SageMaker environment?
return 'SM_TRAINING_ENV' in os.environ
def in_dlts():
# Are we running on a DLTS cluster?
return 'DLTS_JOB_ID' in os.environ
def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""Helper routine to get and set environment variables.
This is adapted from Azure ML's documentation available from:
https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
os.environ["MASTER_ADDR"] = master_node_params[0]
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = str(master_port)
else:
os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = "54965"
if verbose:
logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
if verbose:
logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
"""Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
if verbose:
logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
class DeprecatedException(Exception): class DeprecatedException(Exception):
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file # The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file
...@@ -51,8 +52,7 @@ def initialize(ep_size=1, mpu=None): ...@@ -51,8 +52,7 @@ def initialize(ep_size=1, mpu=None):
def _ensure_divisibility(numerator, denominator): def _ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator.""" """Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format( assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
numerator, denominator)
# Not currently used. Helper function to create a model (tensor) parallel group. # Not currently used. Helper function to create a model (tensor) parallel group.
...@@ -78,8 +78,7 @@ def _create_model_parallel(model_parallel_size_): ...@@ -78,8 +78,7 @@ def _create_model_parallel(model_parallel_size_):
with a total of 16 GPUs, rank 0 to 7 belong to the first box and with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box. ranks 8 to 15 belong to the second box.
""" """
log_dist(f'Creating model parallel group with size {model_parallel_size_}', log_dist(f'Creating model parallel group with size {model_parallel_size_}', ranks=[0])
ranks=[0])
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert dist.is_initialized() assert dist.is_initialized()
world_size = dist.get_world_size() world_size = dist.get_world_size()
...@@ -121,9 +120,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_): ...@@ -121,9 +120,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_):
""" """
assert dist.is_initialized() assert dist.is_initialized()
log_dist( log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0])
f'Creating expert and data parallel groups with size {expert_parallel_size_}',
ranks=[0])
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()
...@@ -139,9 +136,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_): ...@@ -139,9 +136,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_):
for i in range(expert_parallel_size_): for i in range(expert_parallel_size_):
ranks = range(i, world_size, expert_parallel_size_) ranks = range(i, world_size, expert_parallel_size_)
group = dist.new_group(ranks) group = dist.new_group(ranks)
log_dist( log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0])
f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
[0])
if i == (rank % expert_parallel_size_): if i == (rank % expert_parallel_size_):
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
...@@ -153,9 +148,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_): ...@@ -153,9 +148,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_):
for i in range(world_size // expert_parallel_size_): for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
group = dist.new_group(ranks) group = dist.new_group(ranks)
log_dist( log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0])
f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
[0])
if i == (rank // expert_parallel_size_): if i == (rank // expert_parallel_size_):
_EXPERT_PARALLEL_GROUP[group_name] = group _EXPERT_PARALLEL_GROUP[group_name] = group
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
from typing import Callable from typing import Callable
from torch import Tensor from torch import Tensor
...@@ -35,13 +37,10 @@ class OnDevice(object): ...@@ -35,13 +37,10 @@ class OnDevice(object):
if device == "meta": if device == "meta":
if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__): if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__):
raise NotImplementedError( raise NotImplementedError("Meta tensor support is not available, please upgrade to torch 1.10+")
"Meta tensor support is not available, please upgrade to torch 1.10+"
) def fp_tensor_constructor(self, fn: Callable, target_fp_dtype: torch.dtype) -> Callable:
def fp_tensor_constructor(self,
fn: Callable,
target_fp_dtype: torch.dtype) -> Callable:
def wrapped_fn(*args, **kwargs) -> Tensor: def wrapped_fn(*args, **kwargs) -> Tensor:
if kwargs.get("device", None) is None: if kwargs.get("device", None) is None:
kwargs['device'] = self.device kwargs['device'] = self.device
...@@ -53,6 +52,7 @@ class OnDevice(object): ...@@ -53,6 +52,7 @@ class OnDevice(object):
return wrapped_fn return wrapped_fn
def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable: def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable:
def new_tensor(cls, *args) -> Tensor: def new_tensor(cls, *args) -> Tensor:
tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args) tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args)
if tensor.is_floating_point(): if tensor.is_floating_point():
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import functools
import logging import logging
import sys import sys
import os import os
...@@ -14,6 +18,7 @@ log_levels = { ...@@ -14,6 +18,7 @@ log_levels = {
class LoggerFactory: class LoggerFactory:
@staticmethod @staticmethod
def create_logger(name=None, level=logging.INFO): def create_logger(name=None, level=logging.INFO):
"""create a logger """create a logger
...@@ -29,9 +34,8 @@ class LoggerFactory: ...@@ -29,9 +34,8 @@ class LoggerFactory:
if name is None: if name is None:
raise ValueError("name for logger cannot be None") raise ValueError("name for logger cannot be None")
formatter = logging.Formatter( formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] "
"[%(asctime)s] [%(levelname)s] " "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
"[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
logger_ = logging.getLogger(name) logger_ = logging.getLogger(name)
logger_.setLevel(level) logger_.setLevel(level)
...@@ -46,6 +50,21 @@ class LoggerFactory: ...@@ -46,6 +50,21 @@ class LoggerFactory:
logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO) logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)
@functools.lru_cache(None)
def warning_once(*args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
logger.warning(*args, **kwargs)
logger.warning_once = warning_once
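A usage sketch for the cached warning above; identical arguments hit the functools cache, so the message (illustrative text) is emitted only once:

    logger.warning_once("sparse grads not supported, ignoring")  # emitted
    logger.warning_once("sparse grads not supported, ignoring")  # cached, silent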
def print_configuration(args, name): def print_configuration(args, name):
logger.info("{}:".format(name)) logger.info("{}:".format(name))
for arg in sorted(vars(args)): for arg in sorted(vars(args)):
......
""" # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import types import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
def link_hp_params(lp_param_list, def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
flat_hp_partition, param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group):
gradient_dict, local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)
offload_gradient_dict,
use_offload,
param_group_index,
partition_start,
partition_size,
partition_optimizer_state,
dp_group):
local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list,
partition_start,
partition_size,
dp_group)
for lp_param, lp_start in local_lp_param_and_offset: for lp_param, lp_start in local_lp_param_and_offset:
lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
lp_start, offload_gradient_dict, use_offload, param_group_index,
flat_hp_partition, partition_start, partition_size, partition_optimizer_state)
gradient_dict,
offload_gradient_dict,
use_offload,
param_group_index,
partition_start,
partition_size,
partition_optimizer_state)
def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.accelerator import get_accelerator from deepspeed.accelerator import get_accelerator
...@@ -6,6 +9,7 @@ from deepspeed.accelerator import get_accelerator ...@@ -6,6 +9,7 @@ from deepspeed.accelerator import get_accelerator
def instrument_w_nvtx(func): def instrument_w_nvtx(func):
"""decorator that causes an NVTX range to be recorded for the duration of the """decorator that causes an NVTX range to be recorded for the duration of the
function call.""" function call."""
def wrapped_fn(*args, **kwargs): def wrapped_fn(*args, **kwargs):
get_accelerator().range_push(func.__qualname__) get_accelerator().range_push(func.__qualname__)
ret_val = func(*args, **kwargs) ret_val = func(*args, **kwargs)
......
""" # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
...@@ -48,10 +49,7 @@ def get_full_hp_param(self, optim_state_key=None): ...@@ -48,10 +49,7 @@ def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None: if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address lp_frag_address = self._hp_mapping.lp_fragment_address
reduce_fragment = torch.narrow(reduce_buffer, reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
0,
lp_frag_address.start,
lp_frag_address.numel)
if optim_state_key is None: if optim_state_key is None:
hp_fragment = self._hp_mapping.hp_fragment hp_fragment = self._hp_mapping.hp_fragment
else: else:
...@@ -72,21 +70,14 @@ def get_full_hp_grad(self): ...@@ -72,21 +70,14 @@ def get_full_hp_grad(self):
else: else:
gradient_dict = hp_mapping.gradient_dict gradient_dict = hp_mapping.gradient_dict
if hp_mapping.param_group_index not in gradient_dict or gradient_dict[ if hp_mapping.param_group_index not in gradient_dict or gradient_dict[hp_mapping.param_group_index] is None:
hp_mapping.param_group_index] is None: raise ValueError("Gradients are only available immediately after backward and before engine step")
raise ValueError(
"Gradients are only available immediately after backward and before engine step"
)
lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][ lp_grad_fragment = gradient_dict[hp_mapping.param_group_index][self._index_in_param_group]
self._index_in_param_group]
hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten() hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()
lp_frag_address = self._hp_mapping.lp_fragment_address lp_frag_address = self._hp_mapping.lp_fragment_address
reduce_fragment = torch.narrow(reduce_buffer, reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
0,
lp_frag_address.start,
lp_frag_address.numel)
if self.view(-1).shape == hp_grad_fragment.shape: if self.view(-1).shape == hp_grad_fragment.shape:
reduce_buffer.data.copy_(hp_grad_fragment.data) reduce_buffer.data.copy_(hp_grad_fragment.data)
...@@ -150,16 +141,8 @@ def safe_get_full_grad(param): ...@@ -150,16 +141,8 @@ def safe_get_full_grad(param):
return None return None
def get_hp_fragment_mapping(lp_param, def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
lp_start, param_group_index, partition_start, partition_size, optimizer_state_dict):
flat_hp_partition,
gradient_dict,
offload_gradient_dict,
use_offload,
param_group_index,
partition_start,
partition_size,
optimizer_state_dict):
lp_end = lp_param.numel() + lp_start lp_end = lp_param.numel() + lp_start
hp_start = partition_start hp_start = partition_start
hp_end = partition_start + partition_size hp_end = partition_start + partition_size
...@@ -170,25 +153,16 @@ def get_hp_fragment_mapping(lp_param, ...@@ -170,25 +153,16 @@ def get_hp_fragment_mapping(lp_param,
f'fragment start {fragment_start} should be < fragment_end {fragment_end}' f'fragment start {fragment_start} should be < fragment_end {fragment_end}'
fragment_numel = fragment_end - fragment_start fragment_numel = fragment_end - fragment_start
hp_frag_address = fragment_address(start=fragment_start - hp_start, hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
numel=fragment_numel) hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)
hp_fragment_tensor = flat_hp_partition.narrow(0,
hp_frag_address.start,
hp_frag_address.numel)
optim_fragment = { optim_fragment = {
key: value.narrow(0, key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel)
hp_frag_address.start, for key, value in optimizer_state_dict.items()
hp_frag_address.numel)
for key,
value in optimizer_state_dict.items()
if torch.is_tensor(value) and value.shape == flat_hp_partition.shape if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
} }
lp_frag_address = fragment_address(start=fragment_start - lp_start, lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)
lp_fragment_tensor = lp_param.flatten().narrow(0,
lp_frag_address.start,
lp_frag_address.numel)
return tensor_fragment(lp_fragment=lp_fragment_tensor, return tensor_fragment(lp_fragment=lp_fragment_tensor,
lp_fragment_address=lp_frag_address, lp_fragment_address=lp_frag_address,
......
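The fragment bounds above come from intersecting the lp param's flat span with the hp partition's span. A worked example with illustrative offsets:

    lp_start, lp_numel = 6, 10        # lp param occupies flat offsets [6, 16)
    hp_start, hp_end = 8, 20          # this rank's hp partition covers [8, 20)
    lp_end = lp_start + lp_numel
    fragment_start = max(lp_start, hp_start)         # 8
    fragment_end = min(lp_end, hp_end)               # 16
    fragment_numel = fragment_end - fragment_start   # 8 overlapping elements
    assert fragment_start < fragment_end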
""" # Copyright (c) Microsoft Corporation.
Copyright 2019 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import time import time
from numpy import mean from numpy import mean
@@ -18,9 +19,8 @@ except ImportError:
class CudaEventTimer(object):

    def __init__(self, start_event: get_accelerator().Event, end_event: get_accelerator().Event):
        self.start_event = start_event
        self.end_event = end_event
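For context, the event pair is used the way CUDA events usually are: record start, record end, synchronize, then read the elapsed milliseconds. A hedged sketch with plain torch.cuda events (requires a CUDA device; the accelerator abstraction above generalizes this across backends):

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
_ = torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")
end_event.record()

end_event.synchronize()                      # wait until the GPU has passed end_event
print(start_event.elapsed_time(end_event))   # elapsed time in milliseconds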
@@ -32,8 +32,10 @@ class CudaEventTimer(object):
class SynchronizedWallClockTimer:
    """Group of timers. Borrowed from Nvidia Megatron code"""

    class Timer:
        """Timer."""

        def __init__(self, name):
            self.name_ = name
            self.started_ = False
@@ -102,14 +104,12 @@ class SynchronizedWallClockTimer:
    @staticmethod
    def memory_usage():
        alloc = "mem_allocated: {:.4f} GB".format(get_accelerator().memory_allocated() / (1024 * 1024 * 1024))
        max_alloc = "max_mem_allocated: {:.4f} GB".format(get_accelerator().max_memory_allocated() /
                                                          (1024 * 1024 * 1024))
        cache = "cache_allocated: {:.4f} GB".format(get_accelerator().memory_cached() / (1024 * 1024 * 1024))
        max_cache = "max_cache_allocated: {:.4f} GB".format(get_accelerator().max_memory_cached() /
                                                            (1024 * 1024 * 1024))
        return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache)
    def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
@@ -135,6 +135,7 @@ class SynchronizedWallClockTimer:
class ThroughputTimer:

    def __init__(
        self,
        batch_size,
@@ -203,23 +204,19 @@ class ThroughputTimer:
                self.global_step_count,
                self.avg_samples_per_sec(),
                self.batch_size / self.step_elapsed_time,
                round(get_accelerator().memory_allocated() / 1024**3, 2),
                round(get_accelerator().max_memory_allocated() / 1024**3, 2),
            ))
            if self.monitor_memory:
                virt_mem = psutil.virtual_memory()
                swap = psutil.swap_memory()
                self.logging("epoch={}/micro_step={}/global_step={}, vm %: {}, swap %: {}".format(
                    self.epoch_count,
                    self.micro_step_count,
                    self.global_step_count,
                    virt_mem.percent,
                    swap.percent,
                ))
        self.step_elapsed_time = 0

    def avg_samples_per_sec(self):
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from enum import IntEnum
......
#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# This script extracts fp32 consolidated weights from zero 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
@@ -15,18 +19,25 @@ import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass

# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)


@dataclass
class zero_model_state:
    buffers: dict
    param_shapes: dict
    shared_params: list
    ds_version: int
    frozen_param_shapes: dict
    frozen_param_fragments: dict
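For illustration, a hypothetical single-rank instance of this container, as parse_model_states() below would populate it (toy names and shapes, not taken from this diff):

import torch

toy_state = zero_model_state(
    buffers={"bn.running_mean": torch.zeros(4)},            # buffers restored to fp32
    param_shapes=[{"linear.weight": torch.Size([4, 4])}],   # one name->shape dict per param group
    shared_params=[],                                       # [alias, source] name pairs
    ds_version="0.9.2",                                     # version recorded in the checkpoint
    frozen_param_shapes=None,                               # populated only if params were frozen
    frozen_param_fragments=None)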
debug = 0

@@ -63,39 +74,68 @@ def get_model_state_file(checkpoint_dir, zero_stage):
    return file
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    # XXX: need to test that this simple glob rule works for multi-node setup too
    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)

    if len(ckpt_files) == 0:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return ckpt_files


def get_optim_files(checkpoint_dir):
    return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")


def get_model_state_files(checkpoint_dir):
    return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
def parse_model_states(files):
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # collect parameters that are included in param_shapes
        param_names = []
        for s in param_shapes:
            for name in s.keys():
                param_names.append(name)

        # update with frozen parameters
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        z_model_state = zero_model_state(buffers=buffers,
                                         param_shapes=param_shapes,
                                         shared_params=shared_params,
                                         ds_version=ds_version,
                                         frozen_param_shapes=frozen_param_shapes,
                                         frozen_param_fragments=frozen_param_fragments)
        zero_model_states.append(z_model_state)

    return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir):

@@ -132,10 +172,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
        raise ValueError(f"unknown zero stage {zero_stage}")
    if zero_stage == 2:
        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
@@ -144,8 +181,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
        # will require matching the sub-lists of param_shapes for each param group flattened tensor
        fp32_flat_groups = [
            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
        ]

    return zero_stage, world_size, fp32_flat_groups
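Toy illustration of the two layouts returned above (dict keys shortened; the real code indexes with OPTIMIZER_STATE_DICT and a stage-specific fp32_groups_key): zero-2 keeps one flat tensor per param group per rank, while zero-3 concatenates the per-group tensors into a single flat tensor per rank:

import torch

per_rank_groups = [[torch.zeros(4), torch.zeros(2)] for _ in range(2)]     # 2 ranks x 2 param groups
zero2_flat_groups = per_rank_groups                                        # list of lists of flat tensors
zero3_flat_groups = [torch.cat(groups, 0) for groups in per_rank_groups]   # one 6-element tensor per rank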
@@ -163,29 +199,53 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)

    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    if zero_stage == 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)


def _zero2_merge_frozen_params(state_dict, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return
    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

        wanted_params = len(frozen_param_shapes)
        wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
        avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
        print(f'Frozen params: Have {avail_numel} numels to process.')
        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
@@ -194,8 +254,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
@@ -204,26 +263,16 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum(
        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
@@ -239,13 +288,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
        total_params += 1

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
        state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
        offset += unpartitioned_numel

    # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
@@ -268,12 +312,28 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")


def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict
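A hypothetical example of the shared-parameter recovery above: each [alias, source] pair re-points the alias at the tensor already reconstructed for its source:

import torch

state_dict = {"embed.weight": torch.ones(4, 2)}
shared_params = [["lm_head.weight", "embed.weight"]]  # toy tie between output head and embedding
for pair in shared_params:
    if pair[1] in state_dict:
        state_dict[pair[0]] = state_dict[pair[1]]
# state_dict["lm_head.weight"] is now the very same tensor object as state_dict["embed.weight"]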
@@ -285,15 +345,48 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    return partitioned_numel, padding_numel
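A worked example of the partition arithmetic, assuming the usual ceil-division with tail padding (the function body is collapsed in this hunk): 10 elements across 4 ranks gives each rank ceil(10/4) = 3 elements, of which 3*4 - 10 = 2 are padding:

import math

unpartitioned_numel, world_size = 10, 4
partitioned_numel = math.ceil(unpartitioned_numel / world_size)       # 3 per rank
padding_numel = partitioned_numel * world_size - unpartitioned_numel  # 2 padding elements
print(partitioned_numel, padding_numel)  # 3 2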
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

        frozen_param_shapes = zero_model_states[0].frozen_param_shapes
        wanted_params = len(frozen_param_shapes)
        wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
        avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
        print(f'Frozen params: Have {avail_numel} numels to process.')
        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in zero_model_states[0].frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = fp32_flat_groups[0].numel() * world_size
    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
@@ -304,15 +397,9 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
        wanted_params = len(param_shapes)
        wanted_numel = sum(shape.numel() for shape in param_shapes.values())
        # not asserting if there is a mismatch due to possible padding
        avail_numel = fp32_flat_groups[0].numel() * world_size
        print(f"Trainable params: Have {avail_numel} numels to process.")
        print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
@@ -330,30 +417,41 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
            0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict
@@ -465,16 +563,13 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()
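Besides the CLI (python zero_to_fp32.py path/checkpoint-12 path/checkpoint-12/pytorch_model.bin, per the arguments above), the hunk header shows load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None) is also exposed, so the conversion can be driven from Python; a hedged sketch:

# model: a torch.nn.Module with the matching architecture, constructed beforehand
from zero_to_fp32 import load_state_dict_from_zero_checkpoint  # the copy placed in the checkpoint dir

model = load_state_dict_from_zero_checkpoint(model, "path/checkpoint-12")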
......
@@ -95,7 +95,7 @@ ENV PYTHON_VERSION=3
RUN apt-get install -y python3 python3-dev && \
    rm -f /usr/bin/python && \
    ln -s /usr/bin/python3 /usr/bin/python && \
    curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py && \
    python get-pip.py && \
    rm get-pip.py && \
    pip install --upgrade pip && \
......
GEM
remote: https://rubygems.org/
specs:
activesupport (6.0.4.6)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
zeitwerk (~> 2.2, >= 2.2.2)
addressable (2.8.0)
public_suffix (>= 2.0.2, < 5.0)
coffee-script (2.4.1)
coffee-script-source
execjs
coffee-script-source (1.11.1)
colorator (1.1.0)
commonmarker (0.23.4)
ruby-enum (~> 0.5)
concurrent-ruby (1.1.10)
dnsruby (1.61.9)
simpleidn (~> 0.1)
em-websocket (0.5.3)
eventmachine (>= 0.12.9)
http_parser.rb (~> 0)
ethon (0.15.0)
ffi (>= 1.15.0)
eventmachine (1.2.7)
execjs (2.8.1)
faraday (1.10.0)
faraday-em_http (~> 1.0)
faraday-em_synchrony (~> 1.0)
faraday-excon (~> 1.1)
faraday-httpclient (~> 1.0)
faraday-multipart (~> 1.0)
faraday-net_http (~> 1.0)
faraday-net_http_persistent (~> 1.0)
faraday-patron (~> 1.0)
faraday-rack (~> 1.0)
faraday-retry (~> 1.0)
ruby2_keywords (>= 0.0.4)
faraday-em_http (1.0.0)
faraday-em_synchrony (1.0.0)
faraday-excon (1.1.0)
faraday-httpclient (1.0.1)
faraday-multipart (1.0.3)
multipart-post (>= 1.2, < 3)
faraday-net_http (1.0.1)
faraday-net_http_persistent (1.2.0)
faraday-patron (1.0.0)
faraday-rack (1.0.0)
faraday-retry (1.0.3)
ffi (1.15.5)
forwardable-extended (2.6.0)
gemoji (3.0.1)
github-pages (223)
github-pages-health-check (= 1.17.9)
jekyll (= 3.9.0)
jekyll-avatar (= 0.7.0)
jekyll-coffeescript (= 1.1.1)
jekyll-commonmark-ghpages (= 0.1.6)
jekyll-default-layout (= 0.1.4)
jekyll-feed (= 0.15.1)
jekyll-gist (= 1.5.0)
jekyll-github-metadata (= 2.13.0)
jekyll-include-cache (= 0.2.1)
jekyll-mentions (= 1.6.0)
jekyll-optional-front-matter (= 0.3.2)
jekyll-paginate (= 1.1.0)
jekyll-readme-index (= 0.3.0)
jekyll-redirect-from (= 0.16.0)
jekyll-relative-links (= 0.6.1)
jekyll-remote-theme (= 0.4.3)
jekyll-sass-converter (= 1.5.2)
jekyll-seo-tag (= 2.7.1)
jekyll-sitemap (= 1.4.0)
jekyll-swiss (= 1.0.0)
jekyll-theme-architect (= 0.2.0)
jekyll-theme-cayman (= 0.2.0)
jekyll-theme-dinky (= 0.2.0)
jekyll-theme-hacker (= 0.2.0)
jekyll-theme-leap-day (= 0.2.0)
jekyll-theme-merlot (= 0.2.0)
jekyll-theme-midnight (= 0.2.0)
jekyll-theme-minimal (= 0.2.0)
jekyll-theme-modernist (= 0.2.0)
jekyll-theme-primer (= 0.6.0)
jekyll-theme-slate (= 0.2.0)
jekyll-theme-tactile (= 0.2.0)
jekyll-theme-time-machine (= 0.2.0)
jekyll-titles-from-headings (= 0.5.3)
jemoji (= 0.12.0)
kramdown (= 2.3.1)
kramdown-parser-gfm (= 1.1.0)
liquid (= 4.0.3)
mercenary (~> 0.3)
minima (= 2.5.1)
nokogiri (>= 1.12.5, < 2.0)
rouge (= 3.26.0)
terminal-table (~> 1.4)
github-pages-health-check (1.17.9)
addressable (~> 2.3)
dnsruby (~> 1.60)
octokit (~> 4.0)
public_suffix (>= 3.0, < 5.0)
typhoeus (~> 1.3)
html-pipeline (2.14.0)
activesupport (>= 2)
nokogiri (>= 1.4)
http_parser.rb (0.8.0)
i18n (0.9.5)
concurrent-ruby (~> 1.0)
jekyll (3.9.0)
addressable (~> 2.4)
colorator (~> 1.0)
em-websocket (~> 0.5)
i18n (~> 0.7)
jekyll-sass-converter (~> 1.0)
jekyll-watch (~> 2.0)
kramdown (>= 1.17, < 3)
liquid (~> 4.0)
mercenary (~> 0.3.3)
pathutil (~> 0.9)
rouge (>= 1.7, < 4)
safe_yaml (~> 1.0)
jekyll-avatar (0.7.0)
jekyll (>= 3.0, < 5.0)
jekyll-coffeescript (1.1.1)
coffee-script (~> 2.2)
coffee-script-source (~> 1.11.1)
jekyll-commonmark (1.3.1)
commonmarker (~> 0.14)
jekyll (>= 3.7, < 5.0)
jekyll-commonmark-ghpages (0.1.6)
commonmarker (~> 0.17.6)
jekyll-commonmark (~> 1.2)
rouge (>= 2.0, < 4.0)
jekyll-default-layout (0.1.4)
jekyll (~> 3.0)
jekyll-feed (0.15.1)
jekyll (>= 3.7, < 5.0)
jekyll-gist (1.5.0)
octokit (~> 4.2)
jekyll-github-metadata (2.13.0)
jekyll (>= 3.4, < 5.0)
octokit (~> 4.0, != 4.4.0)
jekyll-include-cache (0.2.1)
jekyll (>= 3.7, < 5.0)
jekyll-mentions (1.6.0)
html-pipeline (~> 2.3)
jekyll (>= 3.7, < 5.0)
jekyll-optional-front-matter (0.3.2)
jekyll (>= 3.0, < 5.0)
jekyll-paginate (1.1.0)
jekyll-readme-index (0.3.0)
jekyll (>= 3.0, < 5.0)
jekyll-redirect-from (0.16.0)
jekyll (>= 3.3, < 5.0)
jekyll-relative-links (0.6.1)
jekyll (>= 3.3, < 5.0)
jekyll-remote-theme (0.4.3)
addressable (~> 2.0)
jekyll (>= 3.5, < 5.0)
jekyll-sass-converter (>= 1.0, <= 3.0.0, != 2.0.0)
rubyzip (>= 1.3.0, < 3.0)
jekyll-sass-converter (1.5.2)
sass (~> 3.4)
jekyll-seo-tag (2.7.1)
jekyll (>= 3.8, < 5.0)
jekyll-sitemap (1.4.0)
jekyll (>= 3.7, < 5.0)
jekyll-swiss (1.0.0)
jekyll-theme-architect (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-cayman (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-dinky (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-hacker (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-leap-day (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-merlot (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-midnight (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-minimal (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-modernist (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-primer (0.6.0)
jekyll (> 3.5, < 5.0)
jekyll-github-metadata (~> 2.9)
jekyll-seo-tag (~> 2.0)
jekyll-theme-slate (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-tactile (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-theme-time-machine (0.2.0)
jekyll (> 3.5, < 5.0)
jekyll-seo-tag (~> 2.0)
jekyll-titles-from-headings (0.5.3)
jekyll (>= 3.3, < 5.0)
jekyll-watch (2.2.1)
listen (~> 3.0)
jemoji (0.12.0)
gemoji (~> 3.0)
html-pipeline (~> 2.2)
jekyll (>= 3.0, < 5.0)
kramdown (2.3.1)
rexml
kramdown-parser-gfm (1.1.0)
kramdown (~> 2.0)
liquid (4.0.3)
listen (3.7.1)
rb-fsevent (~> 0.10, >= 0.10.3)
rb-inotify (~> 0.9, >= 0.9.10)
mercenary (0.3.6)
mini_portile2 (2.8.0)
minima (2.5.1)
jekyll (>= 3.5, < 5.0)
jekyll-feed (~> 0.9)
jekyll-seo-tag (~> 2.1)
minimal-mistakes-jekyll (4.24.0)
jekyll (>= 3.7, < 5.0)
jekyll-feed (~> 0.1)
jekyll-gist (~> 1.5)
jekyll-include-cache (~> 0.1)
jekyll-paginate (~> 1.1)
jekyll-sitemap (~> 1.3)
minitest (5.15.0)
multipart-post (2.1.1)
nokogiri (1.13.4)
mini_portile2 (~> 2.8.0)
racc (~> 1.4)
octokit (4.22.0)
faraday (>= 0.9)
sawyer (~> 0.8.0, >= 0.5.3)
pathutil (0.16.2)
forwardable-extended (~> 2.6)
public_suffix (4.0.7)
racc (1.6.0)
rb-fsevent (0.11.1)
rb-inotify (0.10.1)
ffi (~> 1.0)
rexml (3.2.5)
rouge (3.26.0)
ruby-enum (0.9.0)
i18n
ruby2_keywords (0.0.5)
rubyzip (2.3.2)
safe_yaml (1.0.5)
sass (3.7.4)
sass-listen (~> 4.0.0)
sass-listen (4.0.0)
rb-fsevent (~> 0.9, >= 0.9.4)
rb-inotify (~> 0.9, >= 0.9.7)
sawyer (0.8.2)
addressable (>= 2.3.5)
faraday (> 0.8, < 2.0)
simpleidn (0.2.1)
unf (~> 0.1.4)
terminal-table (1.8.0)
unicode-display_width (~> 1.1, >= 1.1.1)
thread_safe (0.3.6)
typhoeus (1.4.0)
ethon (>= 0.9.0)
tzinfo (1.2.9)
thread_safe (~> 0.1)
tzinfo-data (1.2021.5)
tzinfo (>= 1.0.0)
unf (0.1.4)
unf_ext
unf_ext (0.0.8)
unicode-display_width (1.8.0)
wdm (0.1.1)
zeitwerk (2.5.4)
PLATFORMS
ruby
DEPENDENCIES
github-pages
jekyll-feed
jekyll-include-cache
jekyll-paginate
jekyll-remote-theme
minimal-mistakes-jekyll
tzinfo (~> 1.2)
tzinfo-data
wdm (~> 0.1.1)
BUNDLED WITH
2.3.8