Commit 67ea635f authored by aiss's avatar aiss
Browse files

push dsv0.8.2 version

parent 1b2721ad
Pipeline #201 failed with stages
in 0 seconds
'''Copyright The Microsoft DeepSpeed Team'''
from .reshape_meg_2d import reshape_meg_2d_parallel
from .deepspeed_checkpoint import DeepSpeedCheckpoint
from .utils import (get_layer_ckpt_name_for_rank,
get_model_ckpt_name_for_rank,
get_zero_ckpt_name_for_rank)
from .reshape_utils import (merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)
from .zero_checkpoint import ZeROCheckpoint
from .universal_checkpoint import enable_universal_checkpoint
from .constants import *
'''Copyright The Microsoft DeepSpeed Team'''
'''
Various symbolic constants used for model checkpointing
'''
......@@ -11,15 +12,54 @@ FP32_FLAT_GROUPS = 'fp32_flat_groups'
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
GROUPS_PADDING = 'groups_padding'
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
CLIP_GRAD = 'clip_grad'
FP32_WEIGHT_KEY = "fp32"
#########################################
# Module checkpoint keys
#########################################
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
#########################################
# Checkpoint naming constants
#########################################
MODEL_FILE_PREFIX = 'mp_rank_'
ZERO_FILE_PREFIX = 'zero_pp_rank_'
OPTIM_FILE_SUFFIX = '_optim_states.pt'
MODEL_FILE_SUFFIX = '_model_states.pt'
LAYER_FILE_PREFIX = 'layer_'
BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX
#########################################
# Checkpoint utility keys
#########################################
DS_VERSION = 'ds_version'
#########################################
# Universal Checkpoint keys
#########################################
UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
# Vocabulary padding
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
# Parameter splitting/merging
PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
CAT_DIM = "cat_dim"
# Regex list of parameters that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
'''Copyright The Microsoft DeepSpeed Team'''
import os
from typing import Dict
import torch
from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation,
merge_state,
partition_data,
get_files,
get_files_with_prefix)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
from .zero_checkpoint import ZeROCheckpoint
from .constants import *
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
SEQUENTIAL_LAYERS = [
'input_layernorm.weight',
'input_layernorm.bias',
'self_attention.dense.bias',
'post_attention_layernorm.weight',
'post_attention_layernorm.bias',
'mlp.dense_4h_to_h.bias',
'position_embeddings.weight'
]
LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}
class DeepSpeedCheckpoint(object):
def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.dir = dir
self._validate_folder(dir)
self.zero_checkpoint = ZeROCheckpoint(dir)
self.file_list = get_files(dir)
self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)
self.layer_keys = self._get_layer_keys()
self.layer_count = len(self.layer_keys)
self.tp_degree = self.zero_checkpoint.get_src_tp_degree(
) if tp_degree is None else tp_degree
self.pp_degree = self.zero_checkpoint.get_src_pp_degree(
) if pp_degree is None else pp_degree
self.dp_degree = self.zero_checkpoint.get_src_dp_degree(
) if dp_degree is None else dp_degree
self.original_world_size = self.zero_checkpoint.get_src_tp_degree(
) * self.zero_checkpoint.get_src_pp_degree(
) * self.zero_checkpoint.get_src_dp_degree()
self.world_size = self.tp_degree * self.pp_degree * self.dp_degree
self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
self.zero_checkpoint.get_src_tp_degree())
self.old_2d_map.simple_init()
self.new_2d_map = reshape_meg_2d_parallel(
old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)
if self.is_change_pp_degree() or self.is_change_tp_degree(
) or self.is_change_dp_degree():
self.zero_checkpoint.reshape(
model_3d_desc(self.pp_degree,
self.tp_degree,
self.dp_degree))
self.global_state = {}
self._sanity_check()
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(
FINAL_LAYER_NORM_INDEX)
self._build_global_state()
def is_change_tp_degree(self):
return self.tp_degree != self.zero_checkpoint.get_src_tp_degree()
def is_change_pp_degree(self):
return self.pp_degree != self.zero_checkpoint.get_src_pp_degree()
def is_change_dp_degree(self):
return self.dp_degree != self.zero_checkpoint.get_src_dp_degree()
def show_2d_mapping(self):
print(f'reshaped 2d map ---- begin')
for i in range(self.pp_degree):
for j in range(self.tp_degree):
file_list = self.get_2d_parallel_files(pp_index=i, tp_index=j)
print(f'[{i}, {j}] = {file_list}')
print(f'reshaped 2d map ---- end')
def show_tp_embedding_map(self):
self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')
def show_tp_final_norm_map(self):
self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')
def show_pp_tranformer_map(self):
self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')
def show_transformer_file_map(self):
self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')
def _build_global_state(self):
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)
def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
tp_index=tp_index,
dp_index=dp_index,
keys_to_ignore=[PARAM_SHAPES])
def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index,
tp_index=tp_index,
dp_index=dp_index)
def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]
def get_final_norm_layer_id(self):
return self.layer_keys[FINAL_LAYER_NORM_INDEX]
def get_iteration(self):
if not ITERATION_KEY in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
return self.global_state[ITERATION_KEY]
def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [
torch.load(fname,
map_location=torch.device('cpu'))
for fname in self.tp_to_embedding_map[tp_index]
]
sd = self._merge_state_dicts(sd_list)
return sd
def get_embedding_files(self, tp_index: int) -> list:
assert tp_index in self.tp_to_embedding_map.keys()
return self.tp_to_embedding_map[tp_index]
def _get_checkpoint_value(self, key):
if not key in self.global_state:
sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
self.global_state[key] = sd.get(key, None)
return self.global_state[key]
def get_args(self):
return self._get_checkpoint_value(ARGS_KEY)
def get_checkpoint_info(self, info_key=CHECKPOINT_INFO_KEY):
return self._get_checkpoint_value(info_key)
def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [
torch.load(fname,
map_location=torch.device('cpu')) for fname in fname_list
]
merged_sd = None
for sd in sd_list:
if merged_sd is None:
merged_sd = sd
else:
merged_sd = merge_state(merged_sd, sd)
return merged_sd
def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [
torch.load(fname,
map_location=torch.device('cpu')) for fname in fname_list
]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
def get_pp_transformer_map(self, pp_index: int) -> list:
assert pp_index < self.pp_degree
return self.pp_to_transformer_map[pp_index]
def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0],
map_location=torch.device('cpu'))
return sd
def get_final_norm_files(self, tp_index: int) -> list:
assert tp_index in self.tp_to_final_norm_map.keys()
return self.tp_to_final_norm_map[tp_index]
def _build_tp_other_layer_map(self, layer_index: int):
assert layer_index < len(self.layer_files)
layer_files = get_files_with_prefix(self.layer_files,
self.layer_keys[layer_index])
layer_file_partitions = partition_data(layer_files, self.tp_degree)
data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
return data_map
def get_2d_parallel_files(self, tp_index: int, pp_index: int) -> list:
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
file_indices = self.new_2d_map.get_data(pp_index=pp_index, tp_index=tp_index)
return [self.mp_rank_files[i] for i in file_indices]
def _build_pp_transformer_map(self):
data_map = {}
transformer_layers = self.layer_keys[1:-1]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
for i in range(0,
self.pp_degree)
}
return data_map
def _dump_mapping(self, data_map, map_tag=None):
if map_tag is not None:
print(f'Dump mapping: {map_tag}')
for k, v in data_map.items():
print(f'{k} = {v}')
def _build_transformer_file_map(self):
transformer_layer_keys = self.layer_keys[1:-1]
file_map = {}
# XXX: this is not guaranteed
layers_per_pp = len(transformer_layer_keys) // self.pp_degree
if layers_per_pp == 0:
layers_per_pp = 1
#print(f"{transformer_layer_keys} {layers_per_pp}")
for key_index, layer_key in enumerate(transformer_layer_keys):
pp_index = key_index // layers_per_pp
layer_files = get_files_with_prefix(self.layer_files, layer_key)
layer_file_partitions = partition_data(layer_files, self.tp_degree)
for tp_index in range(self.tp_degree):
map_key = (tp_index, pp_index)
if not map_key in file_map.keys():
file_map[map_key] = []
file_map[map_key].append(layer_file_partitions[tp_index])
return file_map
def _sanity_check(self):
assert len(self.mp_rank_files) % self.tp_degree == 0
assert len(self.layer_keys) > 2
assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
# XXX: fix me - isn't always the case
# only true with --pp-partition-method 'type:transformer|embedding' \
# assert (len(self.layer_keys) - 2) % self.pp_degree == 0
def validate_files(self):
for file in self.file_list:
if not os.path.isfile(file):
print(f'Error: {file} is not existent')
def _get_layer_keys(self):
key_set = set()
key_len = len(LAYER_FILE_PREFIX) + 2
for file_path in self.layer_files:
_, fname = os.path.split(file_path)
key_set.add(fname[:key_len])
return sorted(list(key_set))
def _merge_state_dicts(self, sd_list):
merged_sd = {}
for key in sd_list[0].keys():
if not key in SEQUENTIAL_LAYERS:
cat_dim = LAYER_CONCAT_DIM.get(key, 0)
merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
else:
merged_sd[key] = sd_list[0][key]
return merged_sd
def _validate_folder(self, dir):
basic_folder_validation(dir)
file_list = get_files(dir)
for file_prefix in [
MODEL_FILE_PREFIX,
LAYER_FILE_PREFIX,
f'{LAYER_FILE_PREFIX}01'
]:
ckpt_files = get_files_with_prefix(file_list, file_prefix)
assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
'''Copyright The Microsoft DeepSpeed Team'''
from .reshape_utils import (get_files,
get_files_with_prefix,
partition_data,
get_zero_files)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map)
PP_DIM = 'PP'
TP_DIM = 'TP'
DP_DIM = 'DP'
class model_3d_desc(object):
def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
self.dp_degree = dp_degree
def reshape(self, target_3d_desc, verbose=False):
valid_reshape, reshape_errors = self.can_reshape(target_3d_desc)
assert valid_reshape, ','.join(reshape_errors)
tgt_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.pp_degree,
old_tp_degree=self.tp_degree,
new_pp_degree=target_3d_desc.pp_degree,
new_tp_degree=target_3d_desc.tp_degree,
verbose=verbose)
flat_3d_map = flatten_dp_dimension(meg_2d_map=tgt_2d_map,
src_2d_size=self.pp_degree * self.tp_degree,
dp_degree=self.dp_degree)
return unflatten_dp_dimension(meg_2d_map=flat_3d_map,
dp_degree=target_3d_desc.dp_degree)
def get_desc(self):
return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
def world_size(self):
return self.pp_degree * self.tp_degree * self.dp_degree
def is_valid(self, pp_index, tp_index, dp_index):
err_msg = []
valid = True
for index, degree, dim_name in [
(pp_index, self.pp_degree, PP_DIM),
(tp_index, self.tp_degree, TP_DIM),
(dp_index, self.dp_degree, DP_DIM)]:
if index >= degree:
valid = False
err_msg.append(
f'{dim_name} indexing error: index {index} >= degree {degree}')
return valid, err_msg
def can_reshape(self, target_3d_desc):
err_msg = []
if target_3d_desc.pp_degree > self.pp_degree:
err_msg.append(
f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}'
)
if target_3d_desc.tp_degree > self.tp_degree:
err_msg.append(
f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}'
)
if target_3d_desc.dp_degree > self.dp_degree:
err_msg.append(
f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}'
)
return len(err_msg) == 0, err_msg
def get_model_3d_descriptor(dir):
file_list = get_files(dir)
zero_file_list = get_zero_files(dir)
num_pp0_files = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
if num_pp0_files > 0:
tp_degree = num_pp0_files
pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
dp_degree = max(1, len(zero_file_list) // (pp_degree * tp_degree))
else:
tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
dp_degree = max(1, len(zero_file_list) // tp_degree)
pp_degree = 0
return model_3d_desc(pp_degree, tp_degree, dp_degree)
def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):
new_meg_2d_map = meg_2d_parallel_map(meg_2d_map.pp_degree, meg_2d_map.tp_degree)
for pp_index in range(meg_2d_map.pp_degree):
for tp_index in range(meg_2d_map.tp_degree):
dp0_indices = meg_2d_map.get_data(pp_index, tp_index)
for idx in dp0_indices:
dpX_indices = [idx + (i * src_2d_size) for i in range(dp_degree)]
new_meg_2d_map.add_data(pp_index, tp_index, dpX_indices)
return new_meg_2d_map
def unflatten_dp_dimension(meg_2d_map, dp_degree):
pp_degree = meg_2d_map.pp_degree
tp_degree = meg_2d_map.tp_degree
meg_2d_map_list = [
meg_2d_parallel_map(pp_degree=pp_degree,
tp_degree=tp_degree) for _ in range(dp_degree)
]
for pp_index in range(pp_degree):
for tp_index in range(tp_degree):
flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index)
partitioned_dp_indices = partition_data(flat_dp_indices, dp_degree)
for dp_indices, _2d_map in zip(partitioned_dp_indices, meg_2d_map_list):
_2d_map.add_data(pp_index, tp_index, dp_indices)
return meg_2d_map_list
'''Copyright The Microsoft DeepSpeed Team'''
from .reshape_utils import partition_data
class meg_2d_parallel_map(object):
def __init__(self, pp_degree, tp_degree):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
self.map = {}
def simple_init(self):
self.map = {
self._make_key(i // self.tp_degree,
i % self.tp_degree): [i]
for i in range(self.pp_degree * self.tp_degree)
}
def add_data(self, pp_index, tp_index, data):
self._validate_indices(pp_index, tp_index)
assert type(data) is list
key = self._make_key(pp_index, tp_index)
if not key in self.map.keys():
self.map[key] = []
self.map[key] += data
def get_data(self, pp_index=None, tp_index=None):
self._validate_indices(pp_index, tp_index)
pp_indices = list(range(self.pp_degree)) if pp_index is None else [pp_index]
tp_indices = list(range(self.tp_degree)) if tp_index is None else [tp_index]
result = []
for i in pp_indices:
for j in tp_indices:
result += self.map[self._make_key(i, j)]
return result
def print_data(self, tag):
print(f'{tag}')
for key, value in self.map.items():
print(f'{key} = {value}')
def _validate_indices(self, pp_index, tp_index):
assert pp_index is None or pp_index < self.pp_degree
assert tp_index is None or tp_index < self.tp_degree
def _make_key(self, i, j):
return f'{i},{j}'
def _reshape_tp_dimension(old_2d_map, new_tp_degree):
old_pp_degree = old_2d_map.pp_degree
new_2d_map = meg_2d_parallel_map(old_pp_degree, new_tp_degree)
for i in range(old_pp_degree):
ranks_for_pp_index = old_2d_map.get_data(pp_index=i, tp_index=None)
split_ranks = partition_data(ranks_for_pp_index, new_tp_degree)
for j in range(new_tp_degree):
new_2d_map.add_data(i, j, split_ranks[j])
return new_2d_map
def _reshape_pp_dimension(old_2d_map, new_pp_degree):
old_tp_degree = old_2d_map.tp_degree
new_2d_map = meg_2d_parallel_map(new_pp_degree, old_tp_degree)
for i in range(old_tp_degree):
ranks_for_tp_index = old_2d_map.get_data(pp_index=None, tp_index=i)
split_ranks = partition_data(ranks_for_tp_index, new_pp_degree)
for j in range(new_pp_degree):
new_2d_map.add_data(j, i, split_ranks[j])
return new_2d_map
def reshape_meg_2d_parallel(old_pp_degree,
old_tp_degree,
new_pp_degree,
new_tp_degree,
verbose=False):
assert new_pp_degree <= old_pp_degree
assert new_tp_degree <= old_tp_degree
old_2d_map = meg_2d_parallel_map(old_pp_degree, old_tp_degree)
old_2d_map.simple_init()
if verbose:
old_2d_map.print_data(f'original_2d_map:')
if old_tp_degree != new_tp_degree:
new_tp_map = _reshape_tp_dimension(old_2d_map, new_tp_degree)
else:
new_tp_map = old_2d_map
if verbose:
new_tp_map.print_data(f'after_tp_reshape:')
if old_pp_degree != new_pp_degree:
final_map = _reshape_pp_dimension(new_tp_map, new_pp_degree)
else:
final_map = new_tp_map
if verbose:
final_map.print_data(f'final_2d_map:')
return final_map
def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
"""
Initialize model data parallel groups.
Arguments:
tp_size: number of GPUs used to parallelize model tensor.
pp_size: number of GPUs used to parallelize model pipeline.
dp_size: number of GPUs used to parallelize model data.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
world_size = tp_size * pp_size * dp_size
print(f"\n\n*** tp={tp_size}, pp={pp_size}, dp={dp_size}, world={world_size}")
tensor_model_parallel_size = min(tp_size, world_size)
pipeline_model_parallel_size = min(pp_size, world_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
# Build the data-parallel groups.
all_dp_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_dp_group_ranks.append(list(ranks))
print("DP", all_dp_group_ranks)
# Build the model-parallel groups.
all_pp_group_ranks = []
for i in range(data_parallel_size):
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_dp_group_ranks
]
all_pp_group_ranks.append(list(ranks))
print(f"PP", all_pp_group_ranks)
# Build the tensor model-parallel groups.
all_tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
all_tp_group_ranks.append(list(ranks))
print(f"TP", all_tp_group_ranks)
return all_tp_group_ranks, all_pp_group_ranks, all_dp_group_ranks
# # Build the pipeline model-parallel groups and embedding groups
# # (first and last rank in each pipeline model-parallel group).
# for i in range(num_pipeline_model_parallel_groups):
# ranks = range(i, world_size,
# num_pipeline_model_parallel_groups)
# print(f"EMB{i}", list(ranks))
def reshape(src, tgt):
"""
reshape([tp_size_src, pp_size_src, dp_size_src],
[tp_size_tgt, pp_size_tgt, dp_size_tgt])
"""
print(f"\n\n*** Reshaping: {src} => {tgt}")
tp_size_src, pp_size_src, dp_size_src = src
tp_size_tgt, pp_size_tgt, dp_size_tgt = tgt
tp_ranks1, pp_ranks1, dp_ranks1 = get_mpu_ranks(tp_size=tp_size_src, pp_size=pp_size_src, dp_size=dp_size_src)
tp_ranks2, pp_ranks2, dp_ranks2 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_src, dp_size=dp_size_src)
tp_ranks3, pp_ranks3, dp_ranks3 = get_mpu_ranks(tp_size=tp_size_tgt, pp_size=pp_size_tgt, dp_size=dp_size_src)
# handle tp contraction first
print("\n*** TP contraction:")
for i, r in enumerate(tp_ranks1):
print(f'{tp_ranks1[i]} => {tp_ranks2[i]}')
# handle pp contraction next
print("\n*** PP contraction:")
for i, r in enumerate(pp_ranks1):
print(f'{pp_ranks2[i]} => {pp_ranks3[i]}')
# easy
#reshape([2,2,1],[1,1,1])
# probably need more logic to suggest how to pack
#reshape([4,4,1],[2,2,1])
#reshape([2,4,2], [8,32,1])
# get_mpu_ranks(2,2,2)
# get_mpu_ranks(4,2,1)
# get_mpu_ranks(2,4,1)
# get_mpu_ranks(1,1,8)
'''Copyright The Microsoft DeepSpeed Team'''
import os
import torch
from collections import OrderedDict
from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
def basic_folder_validation(dir):
assert os.path.exists(dir), f'{dir} path does not exist'
assert os.path.isdir(dir), f'{dir} is not a folder'
def get_files_with_prefix(all_files, prefix):
file_list = []
for file_path in all_files:
_, fname = os.path.split(file_path)
if fname.startswith(prefix):
file_list.append(file_path)
return sorted(file_list)
def validate_files(file_list):
for file in file_list:
if not os.path.isfile(file):
print(f'Error: {file} is not existent')
def get_files(dir):
file_list = []
for root, _, files in os.walk(dir):
for file in files:
file_list.append(os.path.join(root, file))
return file_list
def get_zero_files(dir):
file_list = get_files(dir)
for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
zero_files = get_files_with_prefix(file_list, prefix)
if len(zero_files) > 0:
return zero_files
return []
def partition_data(data_list, num_partitions):
num_elems = len(data_list)
assert num_elems % num_partitions == 0
partition_size = num_elems // num_partitions
partitions_list = [
data_list[i:i + partition_size] for i in range(0,
num_elems,
partition_size)
]
return partitions_list
def _key_list_to_string(key_list):
return '.'.join(key_list)
def merge_state_dict(dict_a, dict_b, key_list):
merged_dict = type(dict_a)({})
for key, value in dict_b.items():
if key in dict_a.keys():
merged_dict[key] = merge_state(dict_a[key], dict_b[key], [str(key)])
else:
merged_dict[key] = value
return merged_dict
def merge_state_list(list_a, list_b, key_list):
if len(list_a) != len(list_b):
print(f'{_key_list_to_string(key_list)}')
raise ValueError(
f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}'
)
return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)]
def merge_state(state_a, state_b, key_list=[]):
if type(state_a) != type(state_b):
key_list_string = _key_list_to_string(key_list)
print(f'key_list = {key_list_string}')
raise ValueError(
f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')
if type(state_a) in (dict, OrderedDict):
return merge_state_dict(state_a, state_b, key_list)
elif type(state_a) in (list, tuple):
return type(state_a)(merge_state_list(state_a, state_b, key_list))
elif torch.is_tensor(state_a):
return torch.cat([state_a, state_b], 0)
else:
return state_a
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY,
PARAM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
CAT_DIM)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
optim_state_keys = hp_mapping.get_optim_state_keys()
hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys
checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys}
for file in checkpoint_files.values():
assert os.path.isfile(file), f'{file} is not a valid file'
for key in hp_keys:
ckpt_file = checkpoint_files[key]
ckpt_dict = torch.load(ckpt_file)
full_hp_param = ckpt_dict[PARAM]
# need to deal with slices that were averaged.
# the opposite of averaging here becomes an exact copy of the first slice
# I thought of 2 ways:
# implementation a. find a way for a client to pass a dict with patterns
# if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS):
# tp_rank = 0
# tp_world_size = 1
# the other approach is to assume that the saved data is correct and if full_hp_param.shape ==
# self.shape that means we automatically copy?
# implementation b.
# this version requires no additional data passed from the client
# if the shapes already match it must be slices that were averaged - so we just hack around those
if full_hp_param.shape == self.shape:
tp_rank = 0
tp_world_size = 1
# special case for word_embeddings weights which get padded differently depending on TP degree.
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(
VOCAB_DIVISIBILITY_PADDING_TENSOR,
None)
if vocab_divisibility_padding_tensor is not None:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size = self.shape[0] * tp_world_size
if padded_target_vocab_size > full_hp_param.shape[0]:
# Need to expand
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param,
(0,
0,
0,
padding_size),
"constant",
0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
full_hp_param = full_hp_param[:padded_target_vocab_size, :]
full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
# if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
# print_rank_0(f'{full_hp_param[:10]=}', force=True)
assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(
key)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
chunk_dim = ckpt_dict.get(CAT_DIM, 0)
# this performs the opposite of cat when merging TP slices
tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]
tp_hp_slice = tp_hp_slice.flatten()
lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0,
lp_frag_address.start,
lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
# print(f"{key} SHAPE: {tp_hp_slice.shape=}")
# print(f"{key} SHAPE: {dst_tensor.shape=}")
# print(f"{key} SHAPE: {tp_hp_fragment.shape=}")
dst_tensor.data.copy_(tp_hp_fragment.data)
def enable_universal_checkpoint(param_list):
for param in param_list:
param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state,
param)
'''Copyright The Microsoft DeepSpeed Team'''
import os
from .constants import (MODEL_FILE_PREFIX,
MODEL_FILE_SUFFIX,
OPTIM_FILE_SUFFIX,
ZERO_FILE_PREFIX)
def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):
ckpt_name = os.path.join(
base_folder,
MODEL_FILE_PREFIX + mp_rank_str + MODEL_FILE_SUFFIX,
)
return ckpt_name
def get_zero_ckpt_name_for_rank(base_folder, dp_rank, mp_rank):
zero_prefix = f'{ZERO_FILE_PREFIX}{dp_rank}'
mp_rank_string = f'_{MODEL_FILE_PREFIX}{mp_rank:02d}'
zero_ckpt_name = os.path.join(
base_folder,
zero_prefix + mp_rank_string + OPTIM_FILE_SUFFIX,
)
return zero_ckpt_name
def get_layer_ckpt_name_for_rank(base_folder, layer_id, tp_rank):
ckpt_file = f'{layer_id}-model_{tp_rank:02d}{MODEL_FILE_SUFFIX}'
ckpt_path = os.path.join(base_folder, ckpt_file)
return ckpt_path
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from .constants import (BASE_OPTIMIZER_STATE,
GROUP_PADDINGS,
OPTIMIZER_STATE_DICT,
PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)
GROUP_STATE_KEY = 'state'
class ZeROCheckpoint(object):
def __init__(self, dir):
basic_folder_validation(dir)
self.dir = dir
self.file_list = get_zero_files(dir)
self.num_files = len(self.file_list)
assert self.num_files > 0, f'No ZeRO files found in {dir}'
self.src_3d = get_model_3d_descriptor(dir)
self.target_3d = model_3d_desc(pp_degree=self.src_3d.pp_degree,
tp_degree=self.src_3d.tp_degree,
dp_degree=self.src_3d.dp_degree)
self._3d_file_map = self.src_3d.reshape(self.target_3d)
def get_src_world_size(self):
return self.src_3d.world_size()
def get_src_tp_degree(self):
return self.src_3d.tp_degree
def get_src_pp_degree(self):
return self.src_3d.pp_degree
def get_src_dp_degree(self):
return self.src_3d.dp_degree
def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
dp_2d_map = self._3d_file_map[dp_index]
return dp_2d_map.get_data(pp_index, tp_index)
def get_files_for_rank(self, pp_index, tp_index, dp_index):
file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
return [self.file_list[idx] for idx in file_idx_list]
def get_state_for_rank(self,
pp_index,
tp_index,
dp_index,
keys_to_ignore=[],
strip_tensor_paddings=True):
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
sd = torch.load(state_file, map_location=torch.device('cpu'))
for key in keys_to_ignore:
sd.pop(key, None)
if strip_tensor_paddings:
self._strip_tensor_paddings(sd)
if merged_sd is None:
merged_sd = sd
else:
merged_sd = merge_state(merged_sd, sd)
self._update_partition_count(merged_sd)
if strip_tensor_paddings:
self._clear_group_paddings(merged_sd)
return merged_sd
def print_3d_index_map(self, tag=None):
if tag:
print(f'3D index map: {tag}')
for dp_index, _2d_map in enumerate(self._3d_file_map):
_2d_map.print_data(f'dp = {dp_index}')
def print_3d_file_map(self, tag=None):
if tag:
print(f'3D file map: {tag}')
for dp_index, _2d_map in enumerate(self._3d_file_map):
for pp_index in _2d_map.pp_degree:
for tp_index in _2d_map.tp_degree:
file_index_list = _2d_map.get_data(pp_index, tp_index)
file_list = [self.file_list[idx] for idx in file_index_list]
print(f'{pp_index}, {tp_index}, {dp_index} => {file_list}')
def reshape(self, target_3d_desc: model_3d_desc):
self.target_3d = target_3d_desc
self._3d_file_map = self.src_3d.reshape(self.target_3d)
def _strip_tensor_paddings(self, sd):
param_group_states = self._get_param_group_states(sd)
if param_group_states is None:
return
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
if group_paddings is None:
return
for key, group_state in param_group_states.items():
if group_paddings[key] == 0:
continue
for state_name, state_value in group_state.items():
if torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value,
0,
0,
raw_length).clone()
def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
if group_paddings:
num_groups = len(group_paddings)
sd[OPTIMIZER_STATE_DICT][GROUP_PADDINGS] = [0] * num_groups
def _get_optimizer_state(self, sd, state_key):
optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
if optimizer_state is None:
return None
return optimizer_state.get(state_key, None)
def _get_param_group_states(self, sd):
optimizer_state = sd.get(OPTIMIZER_STATE_DICT, None)
if optimizer_state is None:
return None
base_optimizer_state = optimizer_state.get(BASE_OPTIMIZER_STATE, None)
if base_optimizer_state is None:
return None
return base_optimizer_state.get(GROUP_STATE_KEY, None)
def _update_partition_count(self, sd):
partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
if partition_counts:
num_groups = len(partition_counts)
sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree
] * num_groups
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from .utils import *
from deepspeed import utils
supported_torch_version = False
# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The PG API in torch versions lesser than 1.8 are different so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+.
if older_torch():
# Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
# NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
supported_torch_version = False
from torch.distributed import *
def get_world_group():
return group.WORLD
def get_global_rank(group, group_rank):
if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
else:
from torch.distributed.distributed_c10d import _get_global_rank
return _get_global_rank(group, group_rank)
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
from torch.distributed import all_gather, get_world_size
from torch import chunk
output_tensors = list(chunk(output_tensor, get_world_size(group)))
return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)
def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
from torch.distributed import reduce_scatter, get_world_size
from torch import chunk
input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
return reduce_scatter(output_tensor, input_tensor_lst, group=group)
def configure(deepspeed_config=None,
enabled=None,
prof_all=None,
prof_ops=None,
verbose=None):
utils.logger.warn(
"Communication logging is not supported in torch versions older than 1.8")
else:
supported_torch_version = True
from .comm import *
'''Copyright The Microsoft DeepSpeed Team'''
''' DeepSpeed Communication Backend.
In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distributed. Simply wrap torch.distributed for now.
# Custom DS Backends -- Direct C/Ops
- NCCL -- [EXPERIMENTAL]
- MPI -- [EXPERIMENTAL]
- RCCL -- [EXPERIMENTAL]
- GLOO -- [EXPERIMENTAL]
# DS backend wrapper for torch.distributed [DEFAULT]
- T-NCCL -- [DEFAULT]
- T-GLOO
- T-MPI
'''
''' Backend is the base class
-- NcclBackend, MpiBackend, and TorchBackend are the main subclasses. TorchBackend is the only officially supported backend for now.
'''
class Backend(object):
def __init__(self, name='backend', rank=0, size=1):
self.name = name
# The world size and rank of the world process group
self.world_group = None
self.world_size = rank
self.world_rank = size
# Single process group (pg) implementation for now but keep a list for future
self.process_groups = []
self.initialized = False
def is_initialized(self):
return self.initialized
def new_group(self):
# create a new pg and add it to pg list
pass
def init_process_group(self):
# subclasses will initialize them fully
# - initialize a default world process group and add it to pg list
self.initialized = True
"""
Copyright 2021 The Microsoft DeepSpeed Team
DeepSpeed Communication Package: deepspeed.comm
deepspeed.comm
-- import and use deepspeeed.ops.comm
-- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
-- use custom collectives
-- can either use torch.dist or ds.ops.comm?
Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
deepspeed.comm API
-- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
-- e.g. if a client code used
from deepspeed import comm as dist
instead of
import torch.distributed as dist
The code should work without breaking any of the public torch.distributed functionality
Future:
-- deepspeed groups API should be brought into ds.comm
"""
from enum import Enum
import torch
import os
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from deepspeed.accelerator import get_accelerator
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
AVG = 7
UNUSED = 8
from deepspeed.utils.comms_logging import CommsLogger
from deepspeed.utils import timer, get_caller_func
from deepspeed.comm.torch import TorchBackend
from deepspeed import utils
from datetime import timedelta
# Current deepspeed.comm backend (cdb) global object for simple access by client code
use_ds_backend = False
cdb = None
# Create global timer for ops
timers = timer.SynchronizedWallClockTimer()
timer_summary = {}
comms_logger = CommsLogger()
# Ensure we don't warn about base collectives more than once
has_warned_all_gather = False
has_warned_reduce_scatter = False
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *
def _configure_using_config_file(config):
if config.comms_logger_enabled:
comms_logger.configure(config)
def configure(
deepspeed_config=None,
enabled=None,
prof_all=None,
prof_ops=None,
verbose=None,
debug=None,
):
if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config.comms_config)
if enabled is not None:
comms_logger.enabled = enabled
if prof_all is not None:
comms_logger.prof_all = prof_all
if prof_ops is not None:
comms_logger.prof_ops = prof_ops
if verbose is not None:
comms_logger.verbose = verbose
if debug is not None:
comms_logger.debug = debug
# Logging wrapper for timing ops
def timed_op(func):
def log_wrapper(*args, **kwargs):
# Add enabled flag so that overhead to each comm op is two if conditions at most
if comms_logger.enabled:
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
# Need func args for their defaults
func_args = get_default_args(func)
func_args.update(kwargs)
msg_size = get_msg_size_from_args(func, *args, **kwargs)
log_name = get_debug_log_name(func_args, comms_logger.debug)
timers(log_name).start()
# Return the op, then stop the op's timer
try:
return func(*args, **kwargs)
finally:
if comms_logger.enabled:
# Need to make op blocking for accurate logging
get_accelerator().synchronize()
# If we're using MPI, we can't simply sync the stream
if cdb.using_mpi:
cdb.barrier()
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
log_name = get_debug_log_name(func_args, comms_logger.debug)
raw_name = func.__name__
timers(log_name).stop()
# need temp var since 'elapsed' resets events
time_elapsed = timers(log_name).elapsed(reset=False)
comms_logger.append(raw_name, log_name, time_elapsed, msg_size)
return log_wrapper
# For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
# DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
# Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
# UNUSED: Future helper function to initialize DS backends
def init_deepspeed_backend(ds_backend):
global cdb
global nccl_backend
global mpi_backend
global use_ds_backend
if ds_backend == NCCL_BACKEND:
utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
elif ds_backend == MPI_BACKEND:
utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
elif ds_backend == GLOO_BACKEND:
utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
else:
utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
def is_initialized():
#assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb is None:
return False
else:
return cdb.is_initialized()
def destroy_process_group(group=None):
global cdb
return cdb.destroy_process_group(group=group)
def new_group(ranks):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.new_group(ranks)
def is_available() -> bool:
# Returns ``True`` if the deepspeed comm package is available.
# TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
# any communication related primitives from this package.
# use hasattr(deepspeed.csrc.ops, "_comm") or something
return True
def set_backend(backend):
if not use_ds_backend:
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise RuntimeError(
'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
)
global cdb
global nccl_backend
global mpi_backend
try:
if backend_name == NCCL_BACKEND:
if nccl_backend is not None and nccl_backend.is_initialized():
cdb = nccl_backend
elif backend_name == MPI_BACKEND:
if mpi_backend is not None and mpi_backend.is_initialized():
cdb = mpi_backend
except Exception as inst:
print(inst)
@timed_op
def broadcast(tensor,
src,
group=None,
async_op=False,
prof=False,
log_name='broadcast',
debug=get_caller_func()):
global cdb
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
@timed_op
def all_gather(tensor_list,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather',
debug=get_caller_func()):
global cdb
return cdb.all_gather(tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)
def has_reduce_scatter_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
return cdb.has_reduce_scatter_base
def reduce_scatter_fn(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
debug=get_caller_func()):
global cdb
global has_warned_reduce_scatter
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_reduce_scatter_base:
return reduce_scatter_base(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
else:
if not has_warned_reduce_scatter:
utils.logger.warning(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
has_warned_reduce_scatter = True
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
@timed_op
def reduce_scatter_base(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter_base',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter_base(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_gather_base(output_tensor,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather_base',
debug=get_caller_func()):
global cdb
return cdb.all_gather_base(output_tensor=output_tensor,
input_tensor=tensor,
group=group,
async_op=async_op)
def has_allgather_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
return cdb.has_allgather_base
def allgather_fn(output_tensor,
input_tensor,
group=None,
async_op=False,
debug=get_caller_func()):
global cdb
global has_warned_all_gather
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_allgather_base:
return all_gather_base(output_tensor,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
else:
if not has_warned_all_gather and get_rank() == 0:
utils.logger.warning(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
has_warned_all_gather = True
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
return all_gather(output_tensors,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
@timed_op
def all_to_all_single(output,
tensor,
output_split_sizes=None,
input_split_sizes=None,
group=None,
async_op=False,
prof=False,
log_name='all_to_all_single',
debug=get_caller_func()):
global cdb
return cdb.all_to_all_single(output=output,
input=tensor,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=async_op)
@timed_op
def send(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='send',
debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='recv',
debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='isend',
debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='irecv',
debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def gather(tensor,
gather_list=None,
dst=0,
group=None,
async_op=False,
prof=False,
log_name='gather',
debug=get_caller_func()):
global cdb
return cdb.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)
@timed_op
def scatter(tensor,
scatter_list=None,
src=0,
group=None,
async_op=False,
prof=False,
log_name='scatter',
debug=get_caller_func()):
global cdb
return cdb.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)
@timed_op
def barrier(group=None,
async_op=False,
device_ids=None,
prof=False,
log_name='barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@timed_op
def monitored_barrier(group=None,
timeout=None,
wait_all_ranks=False,
prof=False,
log_name='monitored_barrier',
debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
def log_summary():
global cdb
barrier(log_name='log_summary_barrier')
if cdb.get_rank() == 0:
comms_logger.log_all()
barrier(log_name='log_summary_barrier')
@timed_op
def reduce(tensor,
dst,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce',
debug=get_caller_func()):
global cdb
return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
@timed_op
def reduce_scatter(output,
input_list,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter(output=output,
input_list=input_list,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_reduce(tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='all_reduce',
debug=get_caller_func()):
#if profile_comm:
# context of the timers?
# timers.start()
# TensorBoard logging for comm calls.?
global cdb
#print(f'op = {op}, cdb= {cdb.name}')
return cdb.all_reduce(tensor, op, group, async_op)
def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_group()
def get_world_size(group=None) -> int:
"""
Returns the number of processes in the current process group
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
The world size of the process group
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_size(group)
def get_rank(group=None):
"""
Returns the rank of the current process in the provided ``group`` or the
default group if none was provided.
Rank is a unique identifier assigned to each process within a distributed
process group. They are always consecutive integers ranging from 0 to
``world_size``.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
The rank of the process group
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_rank(group)
def get_local_rank():
"""
Helper function to get local rank after a backend has been set and initialized
Args:
None
Returns:
local rank (= GPU device ID)
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return get_local_rank_from_launcher()
def get_global_rank(group=None, group_rank=0):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_global_rank(group, group_rank)
# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,
distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
verbose=True,
timeout=default_pg_timeout,
init_method=None,
dist_init_required=None,
config=None,
rank=-1,
world_size=-1):
''' Initialize dist backend, potentially performing MPI discovery if needed
Arguments:
dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
distributed_port: Optional (int). torch distributed backend port
verbose: Optional (bool). verbose logging
timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
'''
global cdb
configure(deepspeed_config=config)
if dist_init_required is None:
dist_init_required = cdb is None or not cdb.is_initialized()
if cdb is None and torch.distributed.is_initialized():
# The user initialized torch.dist themselves, create cdb and short-circuit
cdb = TorchBackend(dist_backend, timeout, init_method)
return
if dist_init_required is False:
assert (
cdb is not None and cdb.is_initialized() is True
), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
else:
# Initialize torch distributed if needed
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info(
"Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
)
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
elif in_aws_sm():
patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
else:
mpi_discovery(distributed_port=distributed_port, verbose=verbose)
if cdb is not None and cdb.is_initialized():
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info('Distributed backend already initialized')
else:
assert isinstance(timeout, timedelta)
if dist_backend == None:
dist_backend = get_accelerator().communication_backend_name()
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info(
'Initializing TorchBackend in DeepSpeed with backend {}'.format(
dist_backend))
# Create a torch backend object, initialize torch distributed, and assign to cdb
cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
'''
Discovery MPI environment via mpi4py and map to relevant dist state
'''
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(distributed_port)
if verbose:
utils.logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if cdb is not None and cdb.is_initialized():
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
rank, cdb.get_rank())
assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, cdb.get_world_size())
def in_aml():
# Are we running inside an Azure Machine Learning (AML) environment?
return 'AZUREML_EXPERIMENT_ID' in os.environ
def in_aws_sm():
# Are we running inside an AWS SageMaker environment?
return 'SM_TRAINING_ENV' in os.environ
def in_dlts():
# Are we running on a DLTS cluster?
return 'DLTS_JOB_ID' in os.environ
def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""Helper routine to get and set environment variables.
This is adapted from Azure ML's documentation available from:
https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
os.environ["MASTER_ADDR"] = master_node_params[0]
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = str(master_port)
else:
os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
if verbose:
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
if verbose:
utils.logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
"""Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
if verbose:
utils.logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
'''Copyright The Microsoft DeepSpeed Team'''
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from pydantic import BaseModel
from .constants import *
class CommsConfig(BaseModel):
class Config:
validate_all = True
validate_assignment = True
use_enum_values = True
extra = 'forbid'
class CommsLoggerConfig(CommsConfig):
enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
class DeepSpeedCommsConfig:
def __init__(self, ds_config):
self.comms_logger_enabled = 'comms_logger' in ds_config
if self.comms_logger_enabled:
self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])
'''Copyright The Microsoft DeepSpeed Team'''
NCCL_BACKEND = 'nccl'
MPI_BACKEND = 'mpi'
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'
DEFAULT_AML_MASTER_PORT = "54965"
DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"
#########################################
# Comms Logger
#########################################
# Comms Logger. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
COMMS_LOGGER_FORMAT = '''
The Comms Logger can be specified as:
"comms_logger": {
"enabled": true,
"verbose": false,
"prof_all": true,
"debug": false,
"prof_ops": ["all_reduce", "custom_all_reduce_name"]
}
'''
COMMS_LOGGER = "comms_logger"
# Comms logger enable signal
COMMS_LOGGER_ENABLED = "enabled"
COMMS_LOGGER_ENABLED_DEFAULT = False
# Comms logger verbose signal
COMMS_LOGGER_VERBOSE = "verbose"
COMMS_LOGGER_VERBOSE_DEFAULT = False
# comms logger profile all ops signal
COMMS_LOGGER_PROF_ALL = "prof_all"
COMMS_LOGGER_PROF_ALL_DEFAULT = True
# comms logger show all ops signal
COMMS_LOGGER_DEBUG = "debug"
COMMS_LOGGER_DEBUG_DEFAULT = False
# comms logger profile specific ops in list
COMMS_LOGGER_PROF_OPS = "prof_ops"
COMMS_LOGGER_PROF_OPS_DEFAULT = []
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
from deepspeed import utils
from .utils import *
from .backend import *
from .comm import *
class TorchBackend(Backend):
"""
A light-weight wrapper class for torch.distributed API.
Only a subset of functions are wrapped. Once the init_process_group
is initialized, standard torch.distributed.* can be used directly
so no need to wrap all the functions. We can keep adding wrappers as
needed.
"""
def __init__(self,
backend,
timeout,
init_method,
rank=-1,
world_size=-1,
name='torch'):
super(TorchBackend, self).__init__()
self.torch_version_before_18 = older_torch()
self.has_allgather_base = has_allgather_base()
self.has_reduce_scatter_base = has_reduce_scatter_base()
self.initialized = True
self.name = name
# Future functionality to support ds.initialize() on a single GPU
# The idea is to fake that dist backend is initialized even when
# it is not so we can run on a single GPU without doing any init_process_group
self.single_gpu_mode = True
self.init_process_group(backend, timeout, init_method, rank, world_size)
def init_process_group(self, backend, timeout, init_method, rank, world_size):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend,
timeout=timeout,
init_method=init_method,
rank=rank,
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'
def all_reduce(self,
tensor,
op=torch.distributed.ReduceOp.SUM,
group=None,
async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor,
op=op,
group=group,
async_op=async_op)
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor,
dst=dst,
op=self._reduce_op(op),
group=group,
async_op=async_op)
def reduce_scatter(self,
output,
input_list,
op=ReduceOp.SUM,
group=None,
async_op=False):
return torch.distributed.reduce_scatter(output=output,
input_list=input_list,
op=self._reduce_op(op),
group=group,
async_op=async_op)
def broadcast(self, tensor, src, group=None, async_op=False):
return torch.distributed.broadcast(tensor=tensor,
src=src,
group=group,
async_op=async_op)
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return torch.distributed.all_gather(tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_allgather_base:
return torch.distributed.distributed_c10d._all_gather_base(
output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
else:
utils.logger.warning(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def reduce_scatter_base(self,
output_tensor,
input_tensor,
op=ReduceOp.SUM,
group=None,
async_op=False):
if self.has_reduce_scatter_base:
return torch.distributed._reduce_scatter_base(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
else:
utils.logger.warning(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def all_to_all_single(self,
output,
input,
output_split_sizes=None,
input_split_sizes=None,
group=None,
async_op=False):
return torch.distributed.all_to_all_single(output=output,
input=input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=async_op)
def send(self, tensor, dst, group=None, tag=0):
return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
def recv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)
def isend(self, tensor, dst, group=None, tag=0):
return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)
def irecv(self, tensor, src=None, group=None, tag=0):
return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)
def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
return torch.distributed.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)
def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
return torch.distributed.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)
def barrier(self,
group=torch.distributed.GroupMember.WORLD,
async_op=False,
device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group,
async_op=async_op,
device_ids=device_ids)
def monitored_barrier(self,
group=torch.distributed.GroupMember.WORLD,
timeout=None,
wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.monitored_barrier(group=group,
timeout=timeout,
wait_all_ranks=wait_all_ranks)
def get_rank(self, group=None):
return torch.distributed.get_rank(group=group)
def get_world_size(self, group=None):
return torch.distributed.get_world_size(group=group)
def is_initialized(self):
return torch.distributed.is_initialized()
def get_backend(self, group=None):
return torch.distributed.get_backend(group=group)
def new_group(self, ranks):
return torch.distributed.new_group(ranks)
def get_global_rank(self, group, group_rank):
if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
else:
from torch.distributed.distributed_c10d import _get_global_rank
return _get_global_rank(group, group_rank)
def get_world_group(self):
return torch.distributed.group.WORLD
def destroy_process_group(self, group=None):
return torch.distributed.destroy_process_group(group=group)
def _reduce_op(self, op):
'''
Helper function. If the op provided is not a torch.dist.ReduceOp, convert it and return
'''
if not isinstance(op, torch.distributed.ReduceOp):
if op == ReduceOp.SUM:
op = torch.distributed.ReduceOp.SUM
elif op == ReduceOp.PRODUCT:
op = torch.distributed.ReduceOp.PRODUCT
elif op == ReduceOp.AVG:
op = torch.distributed.ReduceOp.AVG
elif op == ReduceOp.MIN:
op = torch.distributed.ReduceOp.MIN
elif op == ReduceOp.MAX:
op = torch.distributed.ReduceOp.MAX
elif op == ReduceOp.BAND:
op = torch.distributed.ReduceOp.BAND
elif op == ReduceOp.BOR:
op = torch.distributed.ReduceOp.BOR
elif op == ReduceOp.BXOR:
op = torch.distributed.ReduceOp.BXOR
return op
# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication
# TODO: make sure there is no performance regression with this approach
# TODO: explore monkey-patching if this does not work
'''Copyright The Microsoft DeepSpeed Team'''
import os
import torch
import inspect
from deepspeed.utils import get_caller_func
def older_torch():
'''
Helper to lookup torch version. For versions less than 1.8, torch.dist
used torch.distributed.group.WORLD as the default group argument instead of None.
See more details at: https://github.com/pytorch/pytorch/pull/48767
'''
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
return True
else:
return False
def has_allgather_base():
'''
Helper to check if torch.distributed has _all_gather_base
'''
return hasattr(torch.distributed, "_all_gather_base")
def has_reduce_scatter_base():
'''
Helper to check if torch.distributed has _reduce_scatter_base
'''
return hasattr(torch.distributed, "_reduce_scatter_base")
def get_local_rank_from_launcher():
# DeepSpeed launcher will set it so get from there
rank = os.environ.get('LOCAL_RANK')
if rank is None:
rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')
# Make it a single process job and set rank to 0
if rank is None:
rank = 0
return int(rank)
def get_world_rank_from_launcher():
# DeepSpeed launcher will set it so get from there
rank = os.environ.get('RANK')
if rank is None:
rank = os.environ.get('OMPI_COMM_WORLD_RANK')
# Make it a single process job and set rank to 0
if rank is None:
rank = 0
return int(rank)
def get_world_size_from_launcher():
# DeepSpeed launcher will set it so get from there
size = os.environ.get('WORLD_SIZE')
rank = os.environ.get('RANK')
if size is None:
size = os.environ.get('OMPI_COMM_WORLD_SIZE')
# Make it a single process job and set size to 1
if size is None:
size = 1
if rank == 0:
print(f"set world size to {size}")
return int(size)
def get_default_args(func):
signature = inspect.signature(func)
return {
k: v.default
for k,
v in signature.parameters.items() if v.default is not inspect.Parameter.empty
}
# We need this hacky function since torch doesn't consistently name or place the input tensor args
def get_tensor_position(func):
sig_params = inspect.signature(func).parameters
arg = None
# most colls
if 'tensor' in sig_params:
arg = 'tensor'
# reduce scatter coll
elif 'input_list' in sig_params:
arg = 'input_list'
# all_to_all and torch multiGPU colls
elif 'input_tensor_list' in sig_params:
arg = 'input_tensor_list'
if arg is None:
return -1
else:
return list(sig_params).index(arg)
def get_tensor_kwarg(func, kwargs):
func_args = get_default_args(func)
func_args.update(kwargs)
arg = None
if 'tensor' in func_args:
arg = func_args['tensor']
elif 'input_list' in func_args:
arg = func_args['input_list']
elif 'input_tensor_list' in func_args:
arg = func_args['input_tensor_list']
return arg
def get_msg_size_from_args(func, *args, **kwargs):
# 3 cases:
# - tensor arg is in args
# - tensor arg is in kwargs
# - tensor arg is not present (e.g. barrier)
tensor_arg_position = -1
tensor_arg = None
# check if tensor arg is in args
if len(args) > 0:
tensor_arg_position = get_tensor_position(func)
if tensor_arg_position > -1:
tensor_arg = args[get_tensor_position(func)]
# check if tensor arg is in kwargs
if tensor_arg is None and len(kwargs) > 0:
tensor_arg = get_tensor_kwarg(func, kwargs)
# if tensor arg is not present, no data is being transmitted
if tensor_arg is None:
return 0
else:
# Sum of tensor sizes for list colls such as torch's all_to_all
# NOTE: msg_size for list colls will not be the actual size transmitted by a given MPI/NCCL call within the coll op. Instead, it's the total amount of data transmitted.
if type(tensor_arg) is list:
return sum(x.element_size() * x.nelement() for x in tensor_arg)
else:
return tensor_arg.element_size() * tensor_arg.nelement()
def get_debug_log_name(func_args, debug):
if debug:
return func_args['log_name'] + ' | [Caller Func: ' + get_caller_func() + ']'
else:
return func_args['log_name']
'''Copyright The Microsoft DeepSpeed Team'''
from .compress import init_compression, redundancy_clean
from .scheduler import compression_scheduler
from .helper import convert_conv1d_to_linear
'''Copyright The Microsoft DeepSpeed Team'''
import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger
g_mpu = None
class QuantAct(nn.Module):
"""
Class to quantize given activations. Note that when using this function, the input acttivation quantization range will be fixed for all
tokens/images for inference. This generally will affect some accuracy but achieve better latency performance.
Parameters:
----------
act_range_momentum : float, default 0.95
Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric'
"""
def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
super(QuantAct, self).__init__()
self.act_range_momentum = act_range_momentum
self.quant_mode = quant_mode
if quant_mode == 'symmetric':
self.act_function = SymQuantizer.apply
else:
self.act_function = AsymQuantizer.apply
self.register_buffer('x_min_max', torch.zeros(2))
def forward(self, x, num_bits, *args):
"""
x: the activation that we need to quantize
num_bits: the number of bits we need to quantize the activation to
*args: some extra arguments that are useless but needed for align with the interface of other quantization functions
"""
if self.training:
x_min = x.data.min()
x_max = x.data.max()
# Initialization
if self.x_min_max[0] == self.x_min_max[1]:
self.x_min_max[0] = x_min
self.x_min_max[1] = x_max
# if do not need momentum, please set self.act_range_momentum = 0
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (
1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (
1 - self.act_range_momentum)
x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
return x_q
class Embedding_Compress(nn.Embedding):
def __init__(self, *kargs):
super(Embedding_Compress, self).__init__(*kargs)
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.weight_quantization_enabled = False
def extra_repr(self):
return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
self.num_embeddings,
self.embedding_dim,
self.weight.target_bits)
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if self.weight.target_bits >= 3:
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
elif self.weight.target_bits == 2:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
self.weight_quantizer = TernaryQuantizer.apply
elif self.weight.target_bits == 1:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
self.weight_quantizer = BinaryQuantizer.apply
# for embedding, we always use token-wise quantization
self.weight_quantize_num_groups = self.weight.size(0)
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups)
else:
weight = self.weight
out = nn.functional.embedding(input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse)
return out
class LinearLayer_Compress(nn.Linear):
"""
Linear layer with compression.
"""
def __init__(self, *kargs, bias=True):
super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
self.sparse_pruning_method = None
self.row_pruning_method = None
self.head_pruning_method = None
self.activation_quantization_method = None
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.weight_quantization_enabled = False
self.sparse_pruning_enabled = False
self.row_pruning_enabled = False
self.head_pruning_enabled = False
self.activation_quantization_enabled = False
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None, \
self.row_pruning_method is not None, self.head_pruning_method is not None, self.activation_quantization_method is not None, self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.sparse_pruning_ratio = ratio
self.sparse_pruning_method = method
if method == 'l1':
weight_norm = torch.abs(self.weight.data)
mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
mask = mask.view(self.weight.size())
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('sparse_pruning_mask', mask)
def enable_row_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.row_pruning_ratio = ratio
self.row_pruning_method = method
if method == 'l1':
# compute the l1 norm of each column
weight_norm = torch.norm(self.weight.data, p=1, dim=1)
mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
mask = mask.view(-1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('row_pruning_mask', mask)
def enable_head_pruning(self, ratio, method, num_heads):
# Here, we support only topk based pruning
self.num_heads = num_heads
self.head_pruning_ratio = ratio
self.head_pruning_method = method
if method not in ['topk']:
raise NotImplementedError
else:
self.head_pruning_ratio = ratio
self.head_pruning_scores = nn.Parameter(torch.Tensor(
1,
self.num_heads)) # we apply the pruning to O matrix
self.head_pruning_scores.data = self.head_pruning_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
def fix_sparse_pruning_helper(self):
mask = self.get_mask(pruning_type='sparse')
self.weight.data = self.weight.data * mask
del self.sparse_pruning_mask
if self.sparse_pruning_method == 'topk':
del self.sparse_mask_scores
self.sparse_pruning_method = None
self.sparse_pruning_enabled = False
return None
def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
# This function is used for row/col pruning
# particularly, if we have two back-to-back layers, F1 and F2; when
# we remove rows from F1, we also need to remove columns from F2
# However, if we only have one layer, F1, then we only need to mask pruned
# rows as 0 in F1
if mask is None:
mask = self.get_mask(pruning_type='row').bool()
if dim_reduction:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
self.out_features = self.weight.size(0)
else:
self.weight.data = self.weight.data * mask.view(-1, 1)
if self.bias is not None:
self.bias.data = self.bias.data * mask.view(-1)
del self.row_pruning_mask
if self.row_pruning_method == 'topk':
del self.row_mask_scores
self.row_pruning_method = None
else:
# this is generally for column pruning
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
self.in_features = self.weight.size(1)
mask = None
self.row_pruning_enabled = False
return mask
def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
# similar as row/col pruning, head pruning also needs to prune QKV which is associated with O matrix
num_heads = num_heads if num_heads else self.num_heads
if mask is None:
if self.head_pruning_method == 'topk':
mask = self.get_mask(pruning_type='head').bool()
if dim_reduction:
shape = self.weight.size(0)
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
else:
shape = self.weight.size()
self.weight.data = (self.weight.data.t().reshape(self.num_heads,
-1) *
mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
if self.head_pruning_method == 'topk':
del self.head_pruning_scores
self.head_pruning_method = None
else:
raise NotImplementedError
else:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
shape = self.weight.size(1)
self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
self.head_pruning_enabled = False
return mask
def get_mask(self, pruning_type='row'):
if pruning_type == 'sparse':
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores,
self.sparse_pruning_ratio,
False)
else:
raise NotImplementedError
if pruning_type == 'row':
if self.row_pruning_method == 'l1':
return self.row_pruning_mask.to(self.weight.device)
elif self.row_pruning_method == 'topk':
return TopKBinarizer.apply(self.row_mask_scores,
self.row_pruning_ratio,
False)
else:
raise NotImplementedError
elif pruning_type == 'head':
if self.head_pruning_method == 'topk':
return TopKBinarizer.apply(self.head_pruning_scores,
self.head_pruning_ratio,
False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if self.weight.target_bits >= 3:
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
elif self.weight.target_bits == 2:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
self.weight_quantizer = TernaryQuantizer.apply
elif self.weight.target_bits == 1:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
self.weight_quantizer = BinaryQuantizer.apply
self.weight_quantize_num_groups = num_groups
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_activation_quantization(self, bits, quantization_type, range_calibration):
assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
self.activation_quantization_bits = bits
self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
if range_calibration == 'static':
self.activation_quantizer = QuantAct(quant_mode=quantization_type)
else:
if quantization_type == 'symmetric':
self.activation_quantizer = SymQuantizer.apply
else:
self.activation_quantizer = AsymQuantizer.apply
def head_pruning_reshape(self, w, mask):
shape = w.shape
return (w.t().reshape(self.num_heads,
-1) * mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
def forward(self, input, skip_bias_add=False):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups)
bias = self.bias
else:
weight = self.weight
bias = self.bias
if self.sparse_pruning_enabled and self.sparse_pruning_method:
mask = self.get_mask(pruning_type='sparse')
weight = weight * mask.view(self.weight.size())
if self.row_pruning_enabled and self.row_pruning_method:
mask = self.get_mask(pruning_type='row')
weight = weight * mask.view(-1, 1)
if bias is not None:
bias = bias * mask.view(-1)
if self.head_pruning_enabled and self.head_pruning_method:
mask = self.get_mask(pruning_type='head')
weight = self.head_pruning_reshape(weight, mask)
if self.activation_quantization_enabled:
if 'dynamic' in self.activation_quantization_method:
num_groups = input.numel() // input.size(-1)
else:
num_groups = 1
input = self.activation_quantizer(input,
self.activation_quantization_bits,
None,
None,
num_groups)
if skip_bias_add:
# used for mpu linear layers
output = nn.functional.linear(input, weight, None)
return output, bias
else:
output = nn.functional.linear(input, weight, bias)
return output
class Conv2dLayer_Compress(nn.Conv2d):
"""
Conv2D layer with compression.
"""
def __init__(self, *kargs):
super(Conv2dLayer_Compress, self).__init__(*kargs)
self.sparse_pruning_method = None
self.channel_pruning_method = None
self.activation_quantization_method = None
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.sparse_pruning_enabled = False
self.channel_pruning_enabled = False
self.activation_quantization_enabled = False
def __repr__(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0, ) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1, ) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0, ) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.padding_mode != 'zeros':
s += ', padding_mode={padding_mode}'
output = s.format(**self.__dict__)
return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
self.sparse_pruning_method is not None,
self.channel_pruning_method is not None,
self.activation_quantization_method is not None,
self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
self.sparse_pruning_ratio = ratio
self.sparse_pruning_method = method
if method == 'l1':
weight_norm = torch.abs(self.weight.data)
mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
mask = mask.view(self.weight.size())
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('sparse_pruning_mask', mask)
def enable_channel_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.channel_pruning_ratio = ratio
self.channel_pruning_method = method
if method == 'l1':
# compute the l1 norm of each conv2d kernel (the last three dimension)
weight_norm = torch.norm(self.weight.data, p=1, dim=[1, 2, 3])
mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
mask = mask.view(-1, 1, 1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.channel_mask_scores = nn.Parameter(
torch.Tensor(self.weight.size(0),
1,
1,
1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('channel_pruning_mask', mask)
def fix_sparse_pruning_helper(self):
mask = self.get_mask(pruning_type='sparse')
self.weight.data = self.weight.data * mask
del self.sparse_pruning_mask
if self.sparse_pruning_method == 'topk':
del self.sparse_mask_scores
self.sparse_pruning_method = None
self.sparse_pruning_enabled = False
return None
def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
if mask is None:
if self.channel_pruning_method in ['l1', 'topk']:
mask = self.get_mask(pruning_type='channel').bool()
if dim_reduction:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
else:
self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
if self.bias is not None:
self.bias.data = self.bias.data * mask.view(-1)
del self.channel_pruning_mask
if self.channel_pruning_method == 'topk':
del self.channel_mask_scores
self.channel_pruning_method = None
else:
raise NotImplementedError
else:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
mask = None
self.channel_pruning_enabled = False
return mask
def get_mask(self, pruning_type='sparse'):
if pruning_type == 'sparse':
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores,
self.sparse_pruning_ratio,
False)
else:
raise NotImplementedError
elif pruning_type == 'channel':
if self.channel_pruning_method == 'l1':
return self.channel_pruning_mask.to(self.weight.device)
elif self.channel_pruning_method == 'topk':
return TopKBinarizer.apply(self.channel_mask_scores,
self.channel_pruning_ratio,
False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
assert self.weight.target_bits >= 4, 'Only >=4 bits weight quantization are supported during forward pass for now'
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
self.weight_quantize_num_groups = num_groups
def enable_activation_quantization(self, bits, quantization_type, range_calibration):
assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
self.activation_quantization_bits = bits
self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
if range_calibration == 'static':
self.activation_quantizer = QuantAct(quant_mode=quantization_type)
else:
if quantization_type == 'symmetric':
self.activation_quantizer = SymQuantizer.apply
else:
self.activation_quantizer = AsymQuantizer.apply
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups)
bias = self.bias
else:
weight = self.weight
bias = self.bias
if self.sparse_pruning_enabled and self.sparse_pruning_method:
mask = self.get_mask(pruning_type='sparse')
weight = weight * mask.view(self.weight.size())
if self.channel_pruning_enabled:
mask = self.get_mask(pruning_type='channel')
weight = weight * mask.view(-1, 1, 1, 1)
if bias is not None:
bias = bias * mask.view(-1)
if self.activation_quantization_enabled:
if 'dynamic' in self.activation_quantization_method:
num_groups = input.numel() // input[0].numel()
else:
num_groups = 1
input = self.activation_quantizer(input,
self.activation_quantization_bits,
None,
None,
num_groups)
return nn.functional.conv2d(input,
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups)
class BNLayer_Compress(nn.BatchNorm2d):
def fix_channel_pruning_helper(self, mask, dim_reduction=True):
self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
self.running_mean = self.running_mean[mask.view(-1)]
self.running_var = self.running_var[mask.view(-1)]
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# All-reduce.
dist.all_reduce(input_, group=group)
return input_
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
assert tensor.size()[last_dim] % num_partitions == 0
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# Split along last dimension.
world_size = dist.get_world_size(group=group)
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous()
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = dist.get_rank(group=group)
world_size = dist.get_world_size(group=group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
class ColumnParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu,
input_size,
output_size,
bias=True,
gather_output=True,
skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
# Divide the weight matrix along the last dimension.
world_size = mpu.get_model_parallel_world_size()
assert output_size % world_size == 0
self.output_size_per_partition = output_size // world_size
super(ColumnParallelLinear_Compress,
self).__init__(self.input_size,
self.output_size_per_partition,
bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
if self.skip_bias_add:
output_parallel, bias = super().forward(input_parallel, True)
else:
output_parallel = super().forward(input_parallel)
bias = None
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output, bias
class RowParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu,
input_size,
output_size,
bias=True,
input_is_parallel=False,
skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
# Divide the weight matrix along the last dimension.
world_size = mpu.get_model_parallel_world_size()
assert input_size % world_size == 0
self.input_size_per_partition = input_size // world_size
super(RowParallelLinear_Compress,
self).__init__(self.input_size_per_partition,
self.output_size,
bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel, bias = super().forward(input_parallel, True)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if not self.skip_bias_add:
if bias is not None:
output = output_ + bias
else:
output = output_
output_bias = None
else:
output = output_
output_bias = bias
return output, output_bias
'''Copyright The Microsoft DeepSpeed Team'''
import re
from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible
from .config import get_compression_config
from ..runtime.config_utils import dict_raise_error_on_duplicate_keys
from .constants import *
import os
import json
def check_deepspeed_config(config):
if isinstance(config, dict):
return config
elif os.path.exists(config):
return json.load(open(config,
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
)
def get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=None,
verbose=True):
'''
get the associated module name from the model based on the key_word provided by users
'''
return_module_name = []
for name, module in model.named_modules():
module_check = is_module_compressible(module, mpu)
if re.search(key_word, name) is not None and module_check:
if name in exist_module_name and verbose:
# logger.warning
raise ValueError(
f"{name} is already added to compression, please check your config file for {group_name}."
)
if name not in exist_module_name:
exist_module_name.add(name)
return_module_name.append(name)
return return_module_name, exist_module_name
def get_compress_methods(model, compress_methods, mpu=None):
# extract the compression module for each method in compress_methods
layer_added_compress_methods = []
for method, method_content in compress_methods.items():
if LAYER_REDUCTION in method:
continue
# for loop different methods, i.e., weight quantization, activation quantization etc
exist_module_name = set()
shared_parameters = method_content[
SHARED_PARAMETERS] # get all the shared parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
# for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc
module_name_list = []
related_module_name_list = []
if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
# this is used for head/row/channel pruning, if users provide the related module scope, we can shrink the layer dim for them
# otherwise we just mask those as zeros
for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE], method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
module_name_list.append(module_name)
tmp_related_module_name_list = []
for rkw in related_key_words:
# related key word can be a list, for instance the QKV for O matrix in Attention
module_name, _ = get_module_name(group_name, model, rkw, set(), mpu=mpu)
tmp_related_module_name_list.append(module_name)
related_module_name_list.append(tmp_related_module_name_list)
else:
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
module_name_list.append(module_name)
if module_name_list:
# combine shared parameters with each group
combined_method_parameters = {
**(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
**shared_parameters
}
compression_item = [
module_name_list,
related_module_name_list,
{
method: combined_method_parameters
}
]
layer_added_compress_methods.append(compression_item)
return layer_added_compress_methods
def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
"""
Compress a model: replace linear/conv2d layer with deepspeed compression-aware modules
Args:
model (`torch.nn.Module`)
The model to compress.
deepspeed_config (`DeepSpeedConfig`)
The path of ds_config
mpu
The mpu module for Row/Column parallelism
"""
compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
if hasattr(model, 'module'):
c_model = model.module
else:
c_model = model
# For layer reduction
if compress_methods[LAYER_REDUCTION][LAYER_REDUCTION_ENABLED]:
assert teacher_model is not None, "Teacher model is required for layer reduction"
student_initialization(c_model, teacher_model, deepspeed_config)
layer_added_compress_methods = get_compress_methods(c_model,
compress_methods,
mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)
return model
def redundancy_clean(model, deepspeed_config, mpu=None):
"""
Remove the redundancy of a model
Args:
model (`torch.nn.Module`)
The model to compress.
deepspeed_config (`DeepSpeedConfig`)
The path of ds_config
mpu
The mpu module for Row/Column parallelism
"""
compress_methods = get_compression_config(check_deepspeed_config(deepspeed_config))
if hasattr(model, 'module'):
c_model = model.module
else:
c_model = model
layer_added_compress_methods_tmp = get_compress_methods(c_model,
compress_methods,
mpu=mpu)
# sort methods
order_list = [
WEIGHT_QUANTIZATION,
SPARSE_PRUNING,
ROW_PRUNING,
HEAD_PRUNING,
CHANNEL_PRUNING,
ACTIVATION_QUANTIZATION
]
layer_added_compress_methods = sorted(
layer_added_compress_methods_tmp,
key=lambda x: order_list.index(list(x[2].keys())[0]))
for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
stored_mask = []
need_mask = True if related_module_name_lists else False
for i, mnl in enumerate(module_name_lists):
for module_name in mnl:
mask = fix_compression(c_model,
module_name,
compression_technique,
dim_reduction=need_mask)
if need_mask:
stored_mask.append(mask)
if need_mask:
for rmnl in related_module_name_lists[i]:
for j, module_name in enumerate(rmnl):
mask = fix_compression(c_model,
module_name,
compression_technique,
mask=stored_mask[j],
dim_reduction=True)
return model
def student_initialization(student_model, teacher_model, deepspeed_config):
'''
Given a student model and a teacher model, select the
Args:
student_model (`torch.nn.Module`)
The model we will update weight
teacher_model (`torch.nn.Module`)
The model guide the student to learn
deepspeed_config (`DeepSpeedConfig`)
The path of ds_config
'''
config = get_compression_config(check_deepspeed_config(deepspeed_config))
compress_methods = config[LAYER_REDUCTION]
module_name_prefix = compress_methods[MODULE_NAME_PREFIX]
teacher_layer = compress_methods[TEACHER_LAYER]
student_layer = [i for i in range(len(teacher_layer))]
other_module_name = compress_methods[OTHER_MODULE_NAME]
'''
name_prefix (`str`)
The prefix name before the layer #.
Example 1: bert.encoder.layer, for BERT_base model's prefix name
Example 2: transformer.h, for GPT-2 hugging face prefix name
teacher_layer (`list of intergers`)
The layer of teacher will be used for student's reinitializedion
Example 1: [1,3,5,7,9], means we want to matches the 2nd/4th/6th/8th/10th layer of teacher to the first 5 layers of student
student_layer (`list` or None)
The layer of student need to be re-intiialized
Example 1: None, means we want to reinitialize all the layers
Example 1: [0,1,2,3,4], means we want to reinitialize the first 5 layers
other_module_name (`list of string`)
The modules will be used for student's reinitializedion
Example 1: ['bert.pooler', 'bert.embeddings', 'classifier'], means we want to apply the weight in teacher's embedding/pooler/classier module to the student
Example 2: ['transformer.w', 'transformer.ln_f', 'lm_head'], means we want to apply the weight in teacher's embeddingn layers module to the student
Note that teacher_layer should matches student layer
'''
assert len(student_layer) == len(teacher_layer)
for s_name, t_name in zip(student_layer, teacher_layer):
s_module = recursive_getattr(student_model,
module_name_prefix + '.' + str(s_name))
t_module = recursive_getattr(teacher_model,
module_name_prefix + '.' + str(t_name))
for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
s_param.data.copy_(t_param.data)
for name in other_module_name:
s_module = recursive_getattr(student_model, name)
t_module = recursive_getattr(teacher_model, name)
print(name)
for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
s_param.data.copy_(t_param.data)
'''Copyright The Microsoft DeepSpeed Team'''
from .constants import *
import copy
from ..runtime.config_utils import get_scalar_param
def get_compression_config(param_dict):
#
output = {}
if COMPRESSION_TRAINING not in param_dict.keys():
param_dict[COMPRESSION_TRAINING] = {}
sub_param_dict = param_dict[COMPRESSION_TRAINING]
output[WEIGHT_QUANTIZATION] = get_weight_quantization(sub_param_dict)
output[ACTIVATION_QUANTIZATION] = get_activation_quantization(sub_param_dict)
output[SPARSE_PRUNING] = get_sparse_pruning(sub_param_dict)
output[ROW_PRUNING] = get_row_pruning(sub_param_dict)
output[HEAD_PRUNING] = get_head_pruning(sub_param_dict)
output[CHANNEL_PRUNING] = get_channel_pruning(sub_param_dict)
output[LAYER_REDUCTION] = get_layer_reduction(sub_param_dict)
return output
def get_layer_reduction(param_dict):
output = {}
output[LAYER_REDUCTION_ENABLED] = LAYER_REDUCTION_ENABLED_DEFAULT
if get_layer_reduction_enabled(param_dict):
output[LAYER_REDUCTION_ENABLED] = get_layer_reduction_enabled(param_dict)
for key, val in get_layer_reduction_params(param_dict).items():
output[key] = val
return output
def get_layer_reduction_enabled(param_dict):
if LAYER_REDUCTION in param_dict.keys():
return get_scalar_param(param_dict[LAYER_REDUCTION],
LAYER_REDUCTION_ENABLED,
LAYER_REDUCTION_ENABLED_DEFAULT)
else:
return False
def get_layer_reduction_params(param_dict):
if LAYER_REDUCTION in param_dict.keys():
layer_reduction_params = copy.copy(param_dict[LAYER_REDUCTION])
layer_reduction_params.pop(LAYER_REDUCTION_ENABLED)
return layer_reduction_params
else:
return False
def get_quantize_enabled(param_dict):
if COMPRESSION_TRAINING not in param_dict.keys():
return False
sub_param_dict = param_dict[COMPRESSION_TRAINING]
output = get_weight_quantization_shared_parameters(sub_param_dict)
return output[WEIGHT_QUANTIZE_ENABLED]
def get_weight_quantization(param_dict):
output = {}
if WEIGHT_QUANTIZATION not in param_dict.keys():
param_dict[WEIGHT_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[WEIGHT_QUANTIZATION]
# shared parameters
output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weigh Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
return output
def get_weight_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_KERNEL,
WEIGHT_QUANTIZE_KERNEL_DEFAULT)
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_GROUPS,
WEIGHT_QUANTIZE_GROUPS_DEFAULT)
output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_VERBOSE,
WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict,
WEIGHT_QUANTIZE_TYPE,
WEIGHT_QUANTIZE_TYPE_DEFAULT)
output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ROUNDING,
WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_CHANGE_RATIO,
WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
else:
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
else:
output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_KERNEL] = WEIGHT_QUANTIZE_KERNEL_DEFAULT
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
output[WEIGHT_QUANTIZE_GROUPS] = WEIGHT_QUANTIZE_GROUPS_DEFAULT
output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
return output
def get_weight_quantization_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(
group_dict,
WEIGHT_QUANTIZATION_PERIOD,
WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
def get_activation_quantization(param_dict):
output = {}
if ACTIVATION_QUANTIZATION not in param_dict.keys():
param_dict[ACTIVATION_QUANTIZATION] = {
SHARED_PARAMETERS: {},
DIFFERENT_GROUPS: {}
}
sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
# shared parameters
output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(
sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(
sub_param_dict)
return output
def get_activation_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZATION_ENABLED,
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_TYPE,
ACTIVATION_QUANTIZE_TYPE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_RANGE,
ACTIVATION_QUANTIZE_RANGE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
else:
output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
output[
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
return output
def get_activation_quantization_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
def get_sparse_pruning(param_dict):
output = {}
if SPARSE_PRUNING not in param_dict.keys():
param_dict[SPARSE_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[SPARSE_PRUNING]
# shared parameters
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output
def get_sparse_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT
return output
def get_sparse_pruning_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
def get_row_pruning(param_dict):
output = {}
if ROW_PRUNING not in param_dict.keys():
param_dict[ROW_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[ROW_PRUNING]
# shared parameters
output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
return output
def get_row_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
ROW_PRUNING_ENABLED,
ROW_PRUNING_ENABLED_DEFAULT)
output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
ROW_PRUNING_METHOD,
ROW_PRUNING_METHOD_DEFAULT)
assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ROW_PRUNING_SCHEDULE_OFFSET,
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
output[ROW_PRUNING_SCHEDULE_OFFSET] = ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT
return output
def get_row_pruning_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
def get_head_pruning(param_dict):
output = {}
if HEAD_PRUNING not in param_dict.keys():
param_dict[HEAD_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[HEAD_PRUNING]
# shared parameters
output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
return output
def get_head_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_ENABLED,
HEAD_PRUNING_ENABLED_DEFAULT)
output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_METHOD,
HEAD_PRUNING_METHOD_DEFAULT)
assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
HEAD_PRUNING_SCHEDULE_OFFSET,
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
if output[HEAD_PRUNING_ENABLED]:
assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
else:
output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
output[HEAD_PRUNING_METHOD] = HEAD_PRUNING_METHOD_DEFAULT
output[HEAD_PRUNING_SCHEDULE_OFFSET] = HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT
return output
def get_head_pruning_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"dense_ratio must be specified for head pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
def get_channel_pruning(param_dict):
output = {}
if CHANNEL_PRUNING not in param_dict.keys():
param_dict[CHANNEL_PRUNING] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[CHANNEL_PRUNING]
# shared parameters
output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
return output
def get_channel_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_ENABLED,
CHANNEL_PRUNING_ENABLED_DEFAULT)
output[CHANNEL_PRUNING_METHOD] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_METHOD,
CHANNEL_PRUNING_METHOD_DEFAULT)
assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_SCHEDULE_OFFSET,
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT
return output
def get_channel_pruning_different_groups(param_dict):
output = {}
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment