Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import re
import collections.abc
@@ -176,6 +179,7 @@ def fetch_hostfile(hostfile_path):
def validate_ds_config(config: dict):
def is_False(config: dict, key):
if config is None:
return False
@@ -189,9 +193,7 @@ def validate_ds_config(config: dict):
if stage == 1:
return True
elif stage == 2:
if is_False(config_zero, "cpu_offload") and is_False(config_zero, "cpu_offload_params"):
return False
elif stage == 3:
offload_devices = ["cpu", "nvme"]
@@ -289,14 +291,13 @@ def get_all_configs(tuning_space: dict, ignore_keys=None):
Args:
tuning_space (dict): the tuning space where tunable parameters are lists of values.
"""
def gen_combinations(d: dict):
keys, values = d.keys(), d.values()
for v in values:
if not isinstance(v, list):
v = [v]
values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
for comb in itertools.product(*values_choices):
yield dict(zip(keys, comb))
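A minimal sketch of how this generator expands a tuning space into concrete configurations; `get_list` is a stand-in assumption for the helper defined elsewhere in this module:
import itertools
def get_list(v):
    # stand-in assumption for the module's own helper: wrap scalars in a list
    return v if isinstance(v, list) else [v]
def gen_combinations(d: dict):
    keys, values = d.keys(), d.values()
    values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
    for comb in itertools.product(*values_choices):
        yield dict(zip(keys, comb))
space = {"train_micro_batch_size_per_gpu": [4, 8], "zero_optimization": {"stage": [1, 2]}}
print(list(gen_combinations(space)))  # 4 combinations: micro-batch {4, 8} x stage {1, 2}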
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .reshape_meg_2d import reshape_meg_2d_parallel
from .deepspeed_checkpoint import DeepSpeedCheckpoint
from .utils import (get_layer_ckpt_name_for_rank, get_model_ckpt_name_for_rank, get_zero_ckpt_name_for_rank)
from .reshape_utils import (merge_state)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Various symbolic constants used for model checkpointing
"""
#########################################
# Optimizer checkpoint keys
@@ -24,6 +27,8 @@ FP32_WEIGHT_KEY = "fp32"
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'
#########################################
# Checkpoint naming constants
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from typing import Dict
import torch
from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation, merge_state, partition_data, get_files, get_files_with_prefix)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@@ -24,19 +23,15 @@ CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
SEQUENTIAL_LAYERS = [
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
'post_attention_layernorm.bias', 'mlp.dense_4h_to_h.bias', 'position_embeddings.weight'
]
LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}
class DeepSpeedCheckpoint(object):
def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.dir = dir
self._validate_folder(dir)
@@ -50,33 +45,24 @@ class DeepSpeedCheckpoint(object):
self.layer_keys = self._get_layer_keys()
self.layer_count = len(self.layer_keys)
self.tp_degree = self.zero_checkpoint.get_src_tp_degree() if tp_degree is None else tp_degree
self.pp_degree = self.zero_checkpoint.get_src_pp_degree() if pp_degree is None else pp_degree
self.dp_degree = self.zero_checkpoint.get_src_dp_degree() if dp_degree is None else dp_degree
self.original_world_size = self.zero_checkpoint.get_src_tp_degree() * self.zero_checkpoint.get_src_pp_degree(
) * self.zero_checkpoint.get_src_dp_degree()
self.world_size = self.tp_degree * self.pp_degree * self.dp_degree
self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
self.zero_checkpoint.get_src_tp_degree())
self.old_2d_map.simple_init()
self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)
if self.is_change_pp_degree() or self.is_change_tp_degree() or self.is_change_dp_degree():
self.zero_checkpoint.reshape(model_3d_desc(self.pp_degree, self.tp_degree, self.dp_degree))
self.global_state = {}
@@ -84,8 +70,7 @@ class DeepSpeedCheckpoint(object):
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
self._build_global_state()
def is_change_tp_degree(self):
@@ -131,9 +116,7 @@ class DeepSpeedCheckpoint(object):
keys_to_ignore=[PARAM_SHAPES])
def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)
def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]
@@ -150,11 +133,7 @@ class DeepSpeedCheckpoint(object):
def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd = self._merge_state_dicts(sd_list)
return sd
@@ -179,10 +158,7 @@ class DeepSpeedCheckpoint(object):
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
merged_sd = None
for sd in sd_list:
@@ -198,10 +174,7 @@ class DeepSpeedCheckpoint(object):
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
@@ -212,8 +185,7 @@ class DeepSpeedCheckpoint(object):
def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
return sd
def get_final_norm_files(self, tp_index: int) -> list:
@@ -222,8 +194,7 @@ class DeepSpeedCheckpoint(object):
def _build_tp_other_layer_map(self, layer_index: int):
assert layer_index < len(self.layer_files)
layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
layer_file_partitions = partition_data(layer_files, self.tp_degree)
data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
return data_map
@@ -238,11 +209,7 @@ class DeepSpeedCheckpoint(object):
data_map = {}
transformer_layers = self.layer_keys[1:-1]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
return data_map
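Illustration of the resulting map (hypothetical sizes): with four transformer layer keys and pp_degree=2, layers_per_pp is 2, so:
# data_map = {0: transformer_layers[0:2], 1: transformer_layers[2:4]}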
def _dump_mapping(self, data_map, map_tag=None):
@@ -308,10 +275,8 @@ class DeepSpeedCheckpoint(object):
file_list = get_files(dir)
for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']:
ckpt_files = get_files_with_prefix(file_list, file_prefix)
assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .reshape_utils import (get_files, get_files_with_prefix, partition_data, get_zero_files)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@@ -15,6 +15,7 @@ DP_DIM = 'DP'
class model_3d_desc(object):
def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
@@ -33,8 +34,7 @@ class model_3d_desc(object):
src_2d_size=self.pp_degree * self.tp_degree,
dp_degree=self.dp_degree)
return unflatten_dp_dimension(meg_2d_map=flat_3d_map, dp_degree=target_3d_desc.dp_degree)
def get_desc(self):
return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
@@ -45,14 +45,11 @@ class model_3d_desc(object):
def is_valid(self, pp_index, tp_index, dp_index):
err_msg = []
valid = True
for index, degree, dim_name in [(pp_index, self.pp_degree, PP_DIM), (tp_index, self.tp_degree, TP_DIM),
(dp_index, self.dp_degree, DP_DIM)]:
if index >= degree:
valid = False
err_msg.append(f'{dim_name} indexing error: index {index} >= degree {degree}')
return valid, err_msg
@@ -60,18 +57,15 @@ class model_3d_desc(object):
err_msg = []
if target_3d_desc.pp_degree > self.pp_degree:
err_msg.append(
f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}')
if target_3d_desc.tp_degree > self.tp_degree:
err_msg.append(
f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}')
if target_3d_desc.dp_degree > self.dp_degree:
err_msg.append(
f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}')
return len(err_msg) == 0, err_msg
@@ -106,10 +100,7 @@ def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):
def unflatten_dp_dimension(meg_2d_map, dp_degree):
pp_degree = meg_2d_map.pp_degree
tp_degree = meg_2d_map.tp_degree
meg_2d_map_list = [meg_2d_parallel_map(pp_degree=pp_degree, tp_degree=tp_degree) for _ in range(dp_degree)]
for pp_index in range(pp_degree):
for tp_index in range(tp_degree):
flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .reshape_utils import partition_data
class meg_2d_parallel_map(object):
def __init__(self, pp_degree, tp_degree):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
@@ -11,8 +15,7 @@ class meg_2d_parallel_map(object):
def simple_init(self):
self.map = {
self._make_key(i // self.tp_degree, i % self.tp_degree): [i]
for i in range(self.pp_degree * self.tp_degree)
}
@@ -74,11 +77,7 @@ def _reshape_pp_dimension(old_2d_map, new_pp_degree):
return new_2d_map
def reshape_meg_2d_parallel(old_pp_degree, old_tp_degree, new_pp_degree, new_tp_degree, verbose=False):
assert new_pp_degree <= old_pp_degree
assert new_tp_degree <= old_tp_degree
@@ -137,8 +136,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
tensor_model_parallel_size = min(tp_size, world_size)
pipeline_model_parallel_size = min(pp_size, world_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
@@ -158,10 +156,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the model-parallel groups.
all_pp_group_ranks = []
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_dp_group_ranks]
all_pp_group_ranks.append(list(ranks))
print(f"PP", all_pp_group_ranks)
@@ -169,8 +164,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the tensor model-parallel groups.
all_tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
all_tp_group_ranks.append(list(ranks))
print(f"TP", all_tp_group_ranks)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
@@ -49,11 +52,7 @@ def partition_data(data_list, num_partitions):
num_elems = len(data_list)
assert num_elems % num_partitions == 0
partition_size = num_elems // num_partitions
partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
return partitions_list
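A quick check of the even split this produces (illustrative values):
# partition_data(list(range(8)), num_partitions=4) -> [[0, 1], [2, 3], [4, 5], [6, 7]]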
@@ -76,9 +75,7 @@ def merge_state_dict(dict_a, dict_b, key_list):
def merge_state_list(list_a, list_b, key_list):
if len(list_a) != len(list_b):
print(f'{_key_list_to_string(key_list)}')
raise ValueError(f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}')
return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)]
@@ -87,8 +84,7 @@ def merge_state(state_a, state_b, key_list=[]):
if type(state_a) != type(state_b):
key_list_string = _key_list_to_string(key_list)
print(f'key_list = {key_list_string}')
raise ValueError(f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')
if type(state_a) in (dict, OrderedDict):
return merge_state_dict(state_a, state_b, key_list)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -44,9 +43,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
if vocab_divisibility_padding_tensor is not None:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
@@ -56,13 +53,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
@@ -76,8 +67,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -90,9 +80,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
tp_hp_slice = tp_hp_slice.flatten()
lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
@@ -104,5 +92,4 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
def enable_universal_checkpoint(param_list):
for param in param_list:
param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)
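The torch.nn.functional.pad call used above appends padding_size zero rows to a 2-D parameter; a minimal sketch of that pad spec:
import torch
import torch.nn.functional as F
t = torch.ones(3, 4)                            # (vocab_rows, hidden_dim)
padded = F.pad(t, (0, 0, 0, 2), "constant", 0)  # pad spec is (left, right, top, bottom) over the last two dims
assert padded.shape == (5, 4) and padded[-2:].eq(0).all()  # two new zero rows at the bottom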
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)
def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
@@ -15,6 +15,7 @@ GROUP_STATE_KEY = 'state'
class ZeROCheckpoint(object):
def __init__(self, dir):
basic_folder_validation(dir)
self.dir = dir
@@ -49,12 +50,7 @@ class ZeROCheckpoint(object):
file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
return [self.file_list[idx] for idx in file_idx_list]
def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], strip_tensor_paddings=True):
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
@@ -111,10 +107,7 @@ class ZeROCheckpoint(object):
for state_name, state_value in group_state.items():
if torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
@@ -144,5 +137,4 @@ class ZeROCheckpoint(object):
partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
if partition_counts:
num_groups = len(partition_counts)
sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .utils import *
from .comm import *
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed Communication Backend.
In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distributed. Simply wrap torch.distributed for now.
@@ -8,18 +12,18 @@ In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distribute
- MPI -- [EXPERIMENTAL]
- RCCL -- [EXPERIMENTAL]
- GLOO -- [EXPERIMENTAL]
# DS backend wrapper for torch.distributed [DEFAULT]
- T-NCCL -- [DEFAULT]
- T-GLOO
- T-MPI
# Backend is the base class
-- NcclBackend, MpiBackend, and TorchBackend are the main subclasses. TorchBackend is the only officially supported backend for now.
"""
class Backend(object):
def __init__(self, name='backend', rank=0, size=1):
self.name = name
# The world size and rank of the world process group
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed Communication Package: deepspeed.comm
deepspeed.comm
-- import and use deepspeed.ops.comm
-- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
-- use custom collectives
-- can either use torch.dist or ds.ops.comm?
Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
deepspeed.comm API
-- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
-- e.g. if client code uses
from deepspeed import comm as dist
instead of
import torch.distributed as dist
the code should work without breaking any of the public torch.distributed functionality
Future:
-- deepspeed groups API should be brought into ds.comm
"""
@@ -63,10 +59,6 @@ timer_summary = {}
comms_logger = CommsLogger()
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
@@ -110,12 +102,13 @@ def configure(
# Logging wrapper for timing ops
def timed_op(func):
def log_wrapper(*args, **kwargs):
# Add enabled flag so that overhead to each comm op is two if conditions at most
if comms_logger.enabled:
if ('prof' in kwargs
and kwargs['prof']) or comms_logger.prof_all or ('log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
# Need func args for their defaults
func_args = get_default_args(func)
func_args.update(kwargs)
@@ -133,8 +126,7 @@ def timed_op(func):
if cdb.using_mpi:
cdb.barrier()
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs and kwargs['log_name'] in comms_logger.prof_ops):
log_name = get_debug_log_name(func_args, comms_logger.debug)
raw_name = func.__name__
timers(log_name).stop()
@@ -182,7 +174,8 @@ def destroy_process_group(group=None):
def new_group(ranks):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.new_group(ranks)
@@ -196,14 +189,12 @@ def is_available() -> bool:
return True
def set_backend(backend_name):
if not use_ds_backend:
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')
global cdb
global nccl_backend
@@ -221,13 +212,7 @@ def set_backend(backend):
@timed_op
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
global cdb
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
@@ -241,17 +226,14 @@ def all_gather(tensor_list,
log_name='all_gather',
debug=get_caller_func()):
global cdb
return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def has_reduce_scatter_tensor():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.has_reduce_scatter_tensor()
def reduce_scatter_fn(output_tensor,
@@ -262,23 +244,21 @@ def reduce_scatter_fn(output_tensor,
prof=False,
debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_reduce_scatter_tensor():
return reduce_scatter_tensor(output_tensor,
tensor,
op=op,
group=group,
async_op=async_op,
prof=prof,
debug=debug)
else:
if get_rank() == 0:
utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
@@ -290,71 +270,54 @@ def reduce_scatter_fn(output_tensor,
@timed_op
def reduce_scatter_tensor(output_tensor,
tensor,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='reduce_scatter_tensor',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
input_tensor=tensor,
op=op,
group=group,
async_op=async_op)
@timed_op
def all_gather_into_tensor(output_tensor,
tensor,
group=None,
async_op=False,
prof=False,
log_name='all_gather_into_tensor',
debug=get_caller_func()):
global cdb
return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
def has_all_gather_into_tensor():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.has_all_gather_into_tensor()
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_all_gather_into_tensor():
return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
else:
if get_rank() == 0:
utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)
@timed_op
@@ -377,49 +340,25 @@ def all_to_all_single(output,
@timed_op
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@@ -434,11 +373,7 @@ def gather(tensor,
log_name='gather',
debug=get_caller_func()):
global cdb
return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)
@timed_op
@@ -451,20 +386,11 @@ def scatter(tensor,
log_name='scatter',
debug=get_caller_func()):
global cdb
return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)
@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@@ -511,11 +437,31 @@ def reduce_scatter(output,
log_name='reduce_scatter',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)
def has_all_reduce_coalesced():
""""""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
return cdb.has_all_reduce_coalesced
def has_coalescing_manager():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
return cdb.has_coalescing_manager
def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)
@timed_op
@@ -535,9 +481,22 @@ def all_reduce(tensor,
return cdb.all_reduce(tensor, op, group, async_op)
@timed_op
def all_reduce_coalesced(tensors,
op=ReduceOp.SUM,
group=None,
async_op=False,
prof=False,
log_name='all_reduce',
debug=get_caller_func()):
global cdb
return cdb.all_reduce_coalesced(tensors, op, group, async_op)
def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_group()
@@ -553,7 +512,8 @@ def get_world_size(group=None) -> int:
"""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_size(group)
@@ -572,7 +532,8 @@ def get_rank(group=None):
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_rank(group)
@@ -585,13 +546,15 @@ def get_local_rank():
local rank (= GPU device ID)
"""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return get_local_rank_from_launcher()
def get_global_rank(group=None, group_rank=0):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_global_rank(group, group_rank)
@@ -640,9 +603,7 @@ def init_distributed(dist_backend=None,
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
elif in_aws_sm():
@@ -658,9 +619,7 @@
if dist_backend == None:
dist_backend = get_accelerator().communication_backend_name()
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend))
# Create a torch backend object, initialize torch distributed, and assign to cdb
cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
@@ -695,16 +654,12 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
if verbose:
utils.logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if cdb is not None and cdb.is_initialized():
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, cdb.get_rank())
assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, cdb.get_world_size())
@@ -731,8 +686,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
@@ -745,8 +699,7 @@
os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
if verbose:
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
@@ -754,10 +707,7 @@
if verbose:
utils.logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
@@ -771,8 +721,5 @@ def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
if verbose:
utils.logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from pydantic import BaseModel
from .constants import *
class CommsConfig(BaseModel):
class Config:
validate_all = True
validate_assignment = True
@@ -25,6 +25,7 @@ class CommsLoggerConfig(CommsConfig):
class DeepSpeedCommsConfig:
def __init__(self, ds_config):
self.comms_logger_enabled = 'comms_logger' in ds_config
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
NCCL_BACKEND = 'nccl'
MPI_BACKEND = 'mpi'
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed import utils
from .utils import *
from .backend import *
from .comm import *
import torch
def is_torch_two():
TORCH_MAJOR = int(torch.__version__.split('.')[0])
if TORCH_MAJOR >= 2:
return True
else:
return False
def has_coalescing_manager():
has_c10d = hasattr(torch.distributed, 'distributed_c10d')
return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')
def has_all_reduce_coalesced():
return hasattr(torch.distributed, "all_reduce_coalesced")
def get_coalescing_manager(group, device, reqs):
if is_torch_two():
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
else:
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
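These probes gate version-dependent code paths; a sketch of the intended usage (illustrative, not part of this file):
# if has_all_reduce_coalesced():
#     torch.distributed.all_reduce_coalesced(tensors)  # one fused call on torch builds that provide it
# else:
#     for t in tensors:
#         torch.distributed.all_reduce(t)              # per-tensor fallback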
class TorchBackend(Backend):
"""
@@ -16,17 +44,13 @@ class TorchBackend(Backend):
so no need to wrap all the functions. We can keep adding wrappers as
needed.
"""
def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.has_all_reduce_coalesced = has_all_reduce_coalesced()
self.has_coalescing_manager = has_coalescing_manager()
self.all_gather_function = self.get_all_gather_function()
self.reduce_scatter_function = self.get_reduce_scatter_function()
self.initialized = True
self.name = name
# Future functionality to support ds.initialize() on a single GPU
@@ -35,6 +59,28 @@
self.single_gpu_mode = True
self.init_process_group(backend, timeout, init_method, rank, world_size)
@classmethod
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
elif hasattr(torch.distributed, "_all_gather_base"):
return torch.distributed._all_gather_base
return None
@classmethod
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
elif hasattr(torch.distributed, "_reduce_scatter_base"):
return torch.distributed._reduce_scatter_base
return None
def has_all_gather_into_tensor(self):
return self.all_gather_function is not None
def has_reduce_scatter_tensor(self):
return self.reduce_scatter_function is not None
def init_process_group(self, backend, timeout, init_method, rank, world_size):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend,
@@ -44,30 +90,24 @@
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
""" proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
"""
if not self.has_all_reduce_coalesced:
raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
f"api (torch.__version__: {torch.__version__})")
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce_scatter(output=output,
input_list=input_list,
op=self._reduce_op(op),
@@ -75,48 +115,57 @@
async_op=async_op)
def broadcast(self, tensor, src, group=None, async_op=False):
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
""""""
assert len(output_tensors) == len(input_tensors), ""
if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
# customized PyTorch
return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
input_tensors,
group=group,
async_op=async_op)
elif has_coalescing_manager():
reqs = []
with get_coalescing_manager(group, input_tensors[0].device, reqs):
for output, input in zip(output_tensors, input_tensors):
handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
input,
group=group,
async_op=True)
reqs.append(handle)
if async_op:
return reqs[-1]
else:
reqs[-1].wait()
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
else:
utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def all_to_all_single(self,
@@ -159,25 +208,15 @@
group=group,
async_op=async_op)
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
def get_rank(self, group=None):
return torch.distributed.get_rank(group=group)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
import inspect
from deepspeed.utils import get_caller_func
def get_local_rank_from_launcher():
# DeepSpeed launcher will set it so get from there
@@ -84,11 +58,7 @@ def get_world_size_from_launcher():
def get_default_args(func):
signature = inspect.signature(func)
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
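For instance (illustrative function):
# def f(a, b=2, c='x'): ...
# get_default_args(f) -> {'b': 2, 'c': 'x'}  (parameters without defaults are skipped)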
# We need this hacky function since torch doesn't consistently name or place the input tensor args
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .compress import init_compression, redundancy_clean
from .scheduler import compression_scheduler
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import math
@@ -21,6 +24,7 @@ class QuantAct(nn.Module):
Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric'
"""
def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
super(QuantAct, self).__init__()
@@ -50,10 +54,8 @@ class QuantAct(nn.Module):
self.x_min_max[1] = x_max
# if do not need momentum, please set self.act_range_momentum = 0
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
@@ -61,6 +63,7 @@
class Embedding_Compress(nn.Embedding):
def __init__(self, *kargs):
super(Embedding_Compress, self).__init__(*kargs)
self.weight.start_bits = None
@@ -71,17 +74,10 @@ class Embedding_Compress(nn.Embedding):
def extra_repr(self):
return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
self.num_embeddings, self.embedding_dim, self.weight.target_bits)
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
......@@ -105,31 +101,20 @@ class Embedding_Compress(nn.Embedding):
self.weight_quantize_num_groups = self.weight.size(0)
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
else:
weight = self.weight
out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
return out
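# Hedged usage sketch (argument values are illustrative, not defaults):
# emb = Embedding_Compress(30522, 768)
# emb.enable_weight_quantization(start_bits=12, target_bits=8, quantization_period=100,
#                                weight_quantization_enabled_in_forward=True,
#                                quantization_type='symmetric', num_groups=1)
# out = emb(input_ids)  # quantized weights are then used transparently in forward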
......@@ -137,6 +122,7 @@ class LinearLayer_Compress(nn.Linear):
"""
Linear layer with compression.
"""
def __init__(self, *kargs, bias=True):
super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
self.sparse_pruning_method = None
......@@ -169,8 +155,7 @@ class LinearLayer_Compress(nn.Linear):
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
......@@ -209,11 +194,9 @@ class LinearLayer_Compress(nn.Linear):
raise NotImplementedError
else:
self.head_pruning_ratio = ratio
self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
self.num_heads)) # we apply the pruning to O matrix
self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
def fix_sparse_pruning_helper(self):
......@@ -279,18 +262,17 @@ class LinearLayer_Compress(nn.Linear):
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
else:
shape = self.weight.size()
self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
shape[1], shape[0]).t()
if self.head_pruning_method == 'topk':
del self.head_pruning_scores
......@@ -316,37 +298,26 @@ class LinearLayer_Compress(nn.Linear):
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
if pruning_type == 'row':
if self.row_pruning_method == 'l1':
return self.row_pruning_mask.to(self.weight.device)
elif self.row_pruning_method == 'topk':
return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'head':
if self.head_pruning_method == 'topk':
return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
......@@ -369,10 +340,7 @@ class LinearLayer_Compress(nn.Linear):
self.weight_quantize_num_groups = num_groups
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
......@@ -391,18 +359,12 @@ class LinearLayer_Compress(nn.Linear):
def head_pruning_reshape(self, w, mask):
shape = w.shape
return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
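# Shape walkthrough (illustrative): for w of shape (out, in), w.t() is (in, out);
# reshape(num_heads, -1) groups all entries belonging to one head per row, so
# mask.view(-1, 1) zeroes whole heads at once; the final reshape and .t()
# restore the original (out, in) layout.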
def forward(self, input, skip_bias_add=False):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
......@@ -428,11 +390,7 @@ class LinearLayer_Compress(nn.Linear):
num_groups = input.numel() // input.size(-1)
else:
num_groups = 1
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
if skip_bias_add:
# used for mpu linear layers
......@@ -447,6 +405,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
"""
Conv2D layer with compression.
"""
def __init__(self, *kargs):
super(Conv2dLayer_Compress, self).__init__(*kargs)
self.sparse_pruning_method = None
......@@ -478,10 +437,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
output = s.format(**self.__dict__)
return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
self.activation_quantization_method is not None, self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
self.sparse_pruning_ratio = ratio
......@@ -493,8 +450,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
......@@ -514,13 +470,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.view(-1, 1, 1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
mask = None
else:
......@@ -579,39 +530,27 @@ class Conv2dLayer_Compress(nn.Conv2d):
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'channel':
if self.channel_pruning_method == 'l1':
return self.channel_pruning_mask.to(self.weight.device)
elif self.channel_pruning_method == 'topk':
return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
......@@ -642,10 +581,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
......@@ -667,22 +603,13 @@ class Conv2dLayer_Compress(nn.Conv2d):
num_groups = input.numel() // input[0].numel()
else:
num_groups = 1
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
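# Hedged usage sketch (shapes and values are illustrative):
# conv = Conv2dLayer_Compress(64, 128, 3)
# conv.enable_sparse_pruning(ratio=0.5, method='topk')
# y = conv(x)  # x: (N, 64, H, W); pruning masks are applied inside forward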
class BNLayer_Compress(nn.BatchNorm2d):
def fix_channel_pruning_helper(self, mask, dim_reduction=True):
self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
......@@ -770,6 +697,7 @@ def _gather(input_):
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_):
return input_
......@@ -781,6 +709,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
......@@ -792,6 +721,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_):
return _split(input_)
......@@ -803,6 +733,7 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
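# Hedged sketch: the backward bodies are elided in this hunk; for a gather the
# gradient is split back to this rank's shard, e.g. (assumed implementation):
# @staticmethod
# def backward(ctx, grad_output):
#     return _split(grad_output)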
......@@ -834,13 +765,8 @@ def gather_from_model_parallel_region(input_):
class ColumnParallelLinear_Compress(LinearLayer_Compress):
def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
......@@ -854,10 +780,7 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
assert output_size % world_size == 0
self.output_size_per_partition = output_size // world_size
super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
......@@ -877,13 +800,8 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
class RowParallelLinear_Compress(LinearLayer_Compress):
def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
......@@ -897,10 +815,7 @@ class RowParallelLinear_Compress(LinearLayer_Compress):
assert input_size % world_size == 0
self.input_size_per_partition = input_size // world_size
super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)
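# e.g. (illustrative) with input_size=1024 and world_size=4, each rank holds a
# weight shard of shape (output_size, 256) and the partial outputs are
# all-reduced across ranks in forward.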
def forward(self, input_):
# Set up backprop all-reduce.
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import re
from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible
......@@ -13,21 +16,13 @@ def check_deepspeed_config(config):
if isinstance(config, dict):
return config
elif os.path.exists(config):
return json.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
)
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}")
def get_module_name(group_name, model, key_word, exist_module_name, mpu=None, verbose=True):
'''
Get the associated module names from the model based on the key_word provided by the user.
'''
......@@ -40,8 +35,7 @@ def get_module_name(group_name,
if name in exist_module_name and verbose:
# logger.warning
raise ValueError(
f"{name} is already added to compression, please check your config file for {group_name}."
)
f"{name} is already added to compression, please check your config file for {group_name}.")
if name not in exist_module_name:
exist_module_name.add(name)
return_module_name.append(name)
......@@ -56,8 +50,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
continue
# loop over the different methods, e.g., weight quantization, activation quantization, etc.
exist_module_name = set()
shared_parameters = method_content[SHARED_PARAMETERS] # get all the shared parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
# loop over the different groups, e.g., weight quantization group 1, weight quantization group 2, etc.
module_name_list = []
......@@ -65,8 +58,13 @@ def get_compress_methods(model, compress_methods, mpu=None):
if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
# This is used for head/row/channel pruning: if users provide the related module scope,
# we can shrink the layer dimension for them; otherwise we just mask those weights as zeros.
for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE],
method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name)
tmp_related_module_name_list = []
for rkw in related_key_words:
......@@ -76,7 +74,11 @@ def get_compress_methods(model, compress_methods, mpu=None):
related_module_name_list.append(tmp_related_module_name_list)
else:
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name)
if module_name_list:
......@@ -85,13 +87,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
**(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
**shared_parameters
}
compression_item = [module_name_list, related_module_name_list, {method: combined_method_parameters}]
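# Layout of each compression_item (example values are illustrative):
# [module_name_list, related_module_name_list, {method: params}], e.g.
# [[['encoder.layer.0.attention.output.dense']], [], {'weight_quantization': {...}}]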
layer_added_compress_methods.append(compression_item)
return layer_added_compress_methods
......@@ -118,9 +114,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
assert teacher_model is not None, "Teacher model is required for layer reduction"
student_initialization(c_model, teacher_model, deepspeed_config)
layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)
return model
......@@ -143,31 +137,20 @@ def redundancy_clean(model, deepspeed_config, mpu=None):
else:
c_model = model
layer_added_compress_methods_tmp = get_compress_methods(c_model, compress_methods, mpu=mpu)
# sort methods
order_list = [
WEIGHT_QUANTIZATION, SPARSE_PRUNING, ROW_PRUNING, HEAD_PRUNING, CHANNEL_PRUNING, ACTIVATION_QUANTIZATION
]
layer_added_compress_methods = sorted(layer_added_compress_methods_tmp,
key=lambda x: order_list.index(list(x[2].keys())[0]))
for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
stored_mask = []
need_mask = bool(related_module_name_lists)
for i, mnl in enumerate(module_name_lists):
for module_name in mnl:
mask = fix_compression(c_model, module_name, compression_technique, dim_reduction=need_mask)
if need_mask:
stored_mask.append(mask)
if need_mask:
......@@ -219,10 +202,8 @@ def student_initialization(student_model, teacher_model, deepspeed_config):
'''
assert len(student_layer) == len(teacher_layer)
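# e.g. (illustrative) student_layer=[0, 1, 2] with teacher_layer=[0, 5, 11]
# copies teacher layers 0, 5 and 11 into student layers 0, 1 and 2 respectively.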
for s_name, t_name in zip(student_layer, teacher_layer):
s_module = recursive_getattr(student_model, module_name_prefix + '.' + str(s_name))
t_module = recursive_getattr(teacher_model, module_name_prefix + '.' + str(t_name))
for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
s_param.data.copy_(t_param.data)
for name in other_module_name:
......