Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import re
import collections.abc
@@ -176,6 +179,7 @@ def fetch_hostfile(hostfile_path):

def validate_ds_config(config: dict):

    def is_False(config: dict, key):
        if config is None:
            return False
@@ -189,9 +193,7 @@ def validate_ds_config(config: dict):
    if stage == 1:
        return True
    elif stage == 2:
        if is_False(config_zero, "cpu_offload") and is_False(config_zero, "cpu_offload_params"):
            return False
    elif stage == 3:
        offload_devices = ["cpu", "nvme"]
@@ -289,14 +291,13 @@ def get_all_configs(tuning_space: dict, ignore_keys=None):
    Args:
        tuning_space (dict): the tuning space where tunable parameters are lists of values.
    """

    def gen_combinations(d: dict):
        keys, values = d.keys(), d.values()
        for v in values:
            if not isinstance(v, list):
                v = [v]
        values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
        for comb in itertools.product(*values_choices):
            yield dict(zip(keys, comb))
...
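As an aside for reviewers: the reflowed gen_combinations above enumerates the Cartesian product of a tuning space, recursing into nested dicts. A minimal standalone sketch (get_list and the toy space here are illustrative stand-ins, not part of this diff):

import itertools

def get_list(v):
    # Stand-in for the helper used above: coerce scalars to singleton lists.
    return v if isinstance(v, list) else [v]

def gen_combinations(d: dict):
    keys, values = d.keys(), d.values()
    values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
    for comb in itertools.product(*values_choices):
        yield dict(zip(keys, comb))

space = {"micro_batch_size": [1, 2], "zero_optimization": {"stage": [1, 2]}}
print(list(gen_combinations(space)))  # 4 configs, nested dicts expanded recursively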
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .reshape_meg_2d import reshape_meg_2d_parallel
from .deepspeed_checkpoint import DeepSpeedCheckpoint
from .utils import (get_layer_ckpt_name_for_rank, get_model_ckpt_name_for_rank, get_zero_ckpt_name_for_rank)
from .reshape_utils import (merge_state)
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Various symbolic constants used for model checkpointing
"""

#########################################
# Optimizer checkpoint keys
@@ -24,6 +27,8 @@ FP32_WEIGHT_KEY = "fp32"
PARAM = 'param'
PARAM_SHAPES = 'param_shapes'
BUFFER_NAMES = 'buffer_names'
FROZEN_PARAM_SHAPES = 'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS = 'frozen_param_fragments'

#########################################
# Checkpoint naming constants
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from typing import Dict
import torch

from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation, merge_state, partition_data, get_files, get_files_with_prefix)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@@ -24,19 +23,15 @@ CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'

SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
    'post_attention_layernorm.bias', 'mlp.dense_4h_to_h.bias', 'position_embeddings.weight'
]

LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}
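For context on how these two tables are consumed: names listed in SEQUENTIAL_LAYERS are replicated across tensor-parallel shards, so merging keeps a single copy, while everything else is concatenated along LAYER_CONCAT_DIM.get(name, 0). A minimal sketch of that assumed merge rule, not the exact DeepSpeed code:

import torch

SEQUENTIAL_LAYERS = ['input_layernorm.weight', 'input_layernorm.bias']  # abbreviated
LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1}

def merge_tp_param(name, shards):
    # Replicated params: keep one copy; partitioned params: concat along their dim.
    if name in SEQUENTIAL_LAYERS:
        return shards[0]
    return torch.cat(shards, dim=LAYER_CONCAT_DIM.get(name, 0))

shards = [torch.ones(4, 3), torch.ones(4, 3)]
print(merge_tp_param('self_attention.dense.weight', shards).shape)  # torch.Size([4, 6])
print(merge_tp_param('input_layernorm.weight', shards).shape)       # torch.Size([4, 3])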
class DeepSpeedCheckpoint(object):

    def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
        self.dir = dir
        self._validate_folder(dir)
@@ -50,33 +45,24 @@ class DeepSpeedCheckpoint(object):
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)

        self.tp_degree = self.zero_checkpoint.get_src_tp_degree() if tp_degree is None else tp_degree
        self.pp_degree = self.zero_checkpoint.get_src_pp_degree() if pp_degree is None else pp_degree
        self.dp_degree = self.zero_checkpoint.get_src_dp_degree() if dp_degree is None else dp_degree

        self.original_world_size = self.zero_checkpoint.get_src_tp_degree() * self.zero_checkpoint.get_src_pp_degree(
        ) * self.zero_checkpoint.get_src_dp_degree()
        self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

        self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
                                              self.zero_checkpoint.get_src_tp_degree())
        self.old_2d_map.simple_init()
        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
                                                  old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
                                                  new_pp_degree=self.pp_degree,
                                                  new_tp_degree=self.tp_degree)

        if self.is_change_pp_degree() or self.is_change_tp_degree() or self.is_change_dp_degree():
            self.zero_checkpoint.reshape(model_3d_desc(self.pp_degree, self.tp_degree, self.dp_degree))

        self.global_state = {}
@@ -84,8 +70,7 @@ class DeepSpeedCheckpoint(object):
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def is_change_tp_degree(self):
@@ -131,9 +116,7 @@ class DeepSpeedCheckpoint(object):
                                                       keys_to_ignore=[PARAM_SHAPES])

    def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
        return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    def get_embedding_layer_id(self):
        return self.layer_keys[EMBEDDING_LAYER_INDEX]
@@ -150,11 +133,7 @@ class DeepSpeedCheckpoint(object):

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map.keys()
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
        sd = self._merge_state_dicts(sd_list)
        return sd
@@ -179,10 +158,7 @@ class DeepSpeedCheckpoint(object):
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]

        merged_sd = None
        for sd in sd_list:
@@ -198,10 +174,7 @@ class DeepSpeedCheckpoint(object):
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list
@@ -212,8 +185,7 @@ class DeepSpeedCheckpoint(object):

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map.keys()
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
        return sd

    def get_final_norm_files(self, tp_index: int) -> list:
@@ -222,8 +194,7 @@ class DeepSpeedCheckpoint(object):

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map
@@ -238,11 +209,7 @@ class DeepSpeedCheckpoint(object):
        data_map = {}
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
@@ -308,10 +275,8 @@ class DeepSpeedCheckpoint(object):
        file_list = get_files(dir)

        for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']:
            ckpt_files = get_files_with_prefix(file_list, file_prefix)
            assert len(
                ckpt_files
            ) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .reshape_utils import (get_files, get_files_with_prefix, partition_data, get_zero_files)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@@ -15,6 +15,7 @@ DP_DIM = 'DP'
class model_3d_desc(object):

    def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1):
        self.pp_degree = pp_degree
        self.tp_degree = tp_degree
@@ -33,8 +34,7 @@ class model_3d_desc(object):
                                           src_2d_size=self.pp_degree * self.tp_degree,
                                           dp_degree=self.dp_degree)

        return unflatten_dp_dimension(meg_2d_map=flat_3d_map, dp_degree=target_3d_desc.dp_degree)

    def get_desc(self):
        return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
@@ -45,14 +45,11 @@ class model_3d_desc(object):

    def is_valid(self, pp_index, tp_index, dp_index):
        err_msg = []
        valid = True
        for index, degree, dim_name in [(pp_index, self.pp_degree, PP_DIM), (tp_index, self.tp_degree, TP_DIM),
                                        (dp_index, self.dp_degree, DP_DIM)]:
            if index >= degree:
                valid = False
                err_msg.append(f'{dim_name} indexing error: index {index} >= degree {degree}')

        return valid, err_msg
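A standalone restatement of the bounds check above (the degrees are illustrative defaults), showing the (valid, err_msg) contract:

def is_valid(pp_index, tp_index, dp_index, pp_degree=2, tp_degree=2, dp_degree=4):
    err_msg = []
    for index, degree, dim_name in [(pp_index, pp_degree, 'PP'), (tp_index, tp_degree, 'TP'),
                                    (dp_index, dp_degree, 'DP')]:
        if index >= degree:
            err_msg.append(f'{dim_name} indexing error: index {index} >= degree {degree}')
    return len(err_msg) == 0, err_msg

print(is_valid(1, 0, 3))  # (True, [])
print(is_valid(2, 0, 0))  # (False, ['PP indexing error: index 2 >= degree 2'])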
@@ -60,18 +57,15 @@ class model_3d_desc(object):
        err_msg = []
        if target_3d_desc.pp_degree > self.pp_degree:
            err_msg.append(
                f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}')

        if target_3d_desc.tp_degree > self.tp_degree:
            err_msg.append(
                f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}')

        if target_3d_desc.dp_degree > self.dp_degree:
            err_msg.append(
                f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}')

        return len(err_msg) == 0, err_msg
@@ -106,10 +100,7 @@ def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):

def unflatten_dp_dimension(meg_2d_map, dp_degree):
    pp_degree = meg_2d_map.pp_degree
    tp_degree = meg_2d_map.tp_degree
    meg_2d_map_list = [meg_2d_parallel_map(pp_degree=pp_degree, tp_degree=tp_degree) for _ in range(dp_degree)]
    for pp_index in range(pp_degree):
        for tp_index in range(tp_degree):
            flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index)
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .reshape_utils import partition_data


class meg_2d_parallel_map(object):

    def __init__(self, pp_degree, tp_degree):
        self.pp_degree = pp_degree
        self.tp_degree = tp_degree
@@ -11,8 +15,7 @@ class meg_2d_parallel_map(object):

    def simple_init(self):
        self.map = {
            self._make_key(i // self.tp_degree, i % self.tp_degree): [i]
            for i in range(self.pp_degree * self.tp_degree)
        }
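The comprehension above lays global ranks out row-major over the (pp, tp) grid; a quick standalone check of that layout (plain tuples stand in for the _make_key encoding):

pp_degree, tp_degree = 2, 2
mapping = {(i // tp_degree, i % tp_degree): [i] for i in range(pp_degree * tp_degree)}
print(mapping)  # {(0, 0): [0], (0, 1): [1], (1, 0): [2], (1, 1): [3]}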
@@ -74,11 +77,7 @@ def _reshape_pp_dimension(old_2d_map, new_pp_degree):
    return new_2d_map


def reshape_meg_2d_parallel(old_pp_degree, old_tp_degree, new_pp_degree, new_tp_degree, verbose=False):
    assert new_pp_degree <= old_pp_degree
    assert new_tp_degree <= old_tp_degree
@@ -137,8 +136,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
    tensor_model_parallel_size = min(tp_size, world_size)
    pipeline_model_parallel_size = min(pp_size, world_size)
    data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
@@ -158,10 +156,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
    # Build the model-parallel groups.
    all_pp_group_ranks = []
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_dp_group_ranks]
        all_pp_group_ranks.append(list(ranks))

    print(f"PP", all_pp_group_ranks)
@@ -169,8 +164,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
    # Build the tensor model-parallel groups.
    all_tp_group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        all_tp_group_ranks.append(list(ranks))

    print(f"TP", all_tp_group_ranks)
...
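To make the grouping concrete, here is the tensor-parallel group construction above run standalone for tp=2, pp=2, dp=2 (world size 8); the sizes are illustrative:

tp_size, pp_size, dp_size = 2, 2, 2
world_size = tp_size * pp_size * dp_size
# Consecutive blocks of tp_size ranks form one tensor-parallel group each.
all_tp_group_ranks = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(world_size // tp_size)]
print(all_tp_group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]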
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
@@ -49,11 +52,7 @@ def partition_data(data_list, num_partitions):
    num_elems = len(data_list)
    assert num_elems % num_partitions == 0
    partition_size = num_elems // num_partitions
    partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
    return partitions_list
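The reflowed comprehension keeps partition_data's behavior: contiguous, equal-sized slices. For example:

def partition_data(data_list, num_partitions):
    num_elems = len(data_list)
    assert num_elems % num_partitions == 0
    partition_size = num_elems // num_partitions
    return [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]

print(partition_data(list(range(6)), 3))  # [[0, 1], [2, 3], [4, 5]]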
@@ -76,9 +75,7 @@ def merge_state_dict(dict_a, dict_b, key_list):

def merge_state_list(list_a, list_b, key_list):
    if len(list_a) != len(list_b):
        print(f'{_key_list_to_string(key_list)}')
        raise ValueError(f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}')

    return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)]
@@ -87,8 +84,7 @@ def merge_state(state_a, state_b, key_list=[]):
    if type(state_a) != type(state_b):
        key_list_string = _key_list_to_string(key_list)
        print(f'key_list = {key_list_string}')
        raise ValueError(f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')

    if type(state_a) in (dict, OrderedDict):
        return merge_state_dict(state_a, state_b, key_list)
...
""" # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import os import os
import torch import torch
import types import types
from .constants import (FP32_WEIGHT_KEY, from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
PARAM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
CAT_DIM)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@@ -44,9 +43,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
        # the converter to universal currently strips the original padding completely so the saved
        # weight is padding-free and we just need to add new padding depending on the target TP
        # degree
        vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
        if vocab_divisibility_padding_tensor is not None:
            # In the absence of data passed from the user wrt new padded vocab specific to tp degree
            # we can again derive that data by reverse engineering the target shapes like so:
@@ -56,13 +53,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                # Implement the following concat in efficient way using pad
                #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
                full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
                full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
            else:
                # Need to shrink or keep the same
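The pad spec reads right-to-left over dimensions: (0, 0) leaves the last dim alone and (0, padding_size) appends rows to dim 0, which is what the commented-out torch.cat did. A quick standalone check:

import torch
import torch.nn.functional as F

x = torch.ones(3, 4)
padded = F.pad(x, (0, 0, 0, 2), "constant", 0)  # append 2 zero rows at the bottom
print(padded.shape)      # torch.Size([5, 4])
print(padded[3:].sum())  # tensor(0.) -- the appended rows are zeros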
@@ -76,8 +67,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
        assert full_param_numel == tp_world_size * tp_slice_numel, \
            f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
        dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)

        # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
        # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@@ -90,9 +80,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
        tp_hp_slice = tp_hp_slice.flatten()

        lp_frag_address = hp_mapping.lp_fragment_address
        tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
        assert dst_tensor.numel() == lp_frag_address.numel, \
            f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
@@ -104,5 +92,4 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

def enable_universal_checkpoint(param_list):
    for param in param_list:
        param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)
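types.MethodType binds the loader to each parameter instance, so param.load_hp_checkpoint_state(...) receives that param as self. A minimal sketch with a stand-in parameter class:

import types

class FakeParam:  # stand-in for a torch parameter object
    pass

def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
    print(f"{self} loads {folder} as slice {tp_rank}/{tp_world_size}")

p = FakeParam()
p.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, p)
p.load_hp_checkpoint_state("ckpt/universal", 0, 2)  # bound per instance, as above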
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os

from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)


def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
@@ -15,6 +15,7 @@ GROUP_STATE_KEY = 'state'
class ZeROCheckpoint(object):

    def __init__(self, dir):
        basic_folder_validation(dir)
        self.dir = dir
@@ -49,12 +50,7 @@ class ZeROCheckpoint(object):
        file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
        return [self.file_list[idx] for idx in file_idx_list]
    def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], strip_tensor_paddings=True):
        state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
        merged_sd = None
        for state_file in state_file_list:
@@ -111,10 +107,7 @@ class ZeROCheckpoint(object):
            for state_name, state_value in group_state.items():
                if torch.is_tensor(state_value):
                    raw_length = state_value.numel() - group_paddings[key]
                    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
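torch.narrow(t, 0, 0, raw_length) keeps the first raw_length elements, i.e. it strips the alignment padding counted in group_paddings, and .clone() copies so the result no longer aliases the padded storage. Standalone:

import torch

t = torch.arange(10)  # pretend: 7 real elements + 3 elements of padding
stripped = torch.narrow(t, 0, 0, 7).clone()
print(stripped)       # tensor([0, 1, 2, 3, 4, 5, 6])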
    def _clear_group_paddings(self, sd):
        group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
@@ -144,5 +137,4 @@ class ZeROCheckpoint(object):
        partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
        if partition_counts:
            num_groups = len(partition_counts)
            sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .utils import *
from .comm import *

This replaces the previous module body, which shimmed the pre-1.8 torch.distributed API:

import torch
from .utils import *
from deepspeed import utils

supported_torch_version = False

# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The PG API in torch versions older than 1.8 is different, so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if the torch version is 1.8+.

if older_torch():
    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
    # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
    supported_torch_version = False
    from torch.distributed import *

    def get_world_group():
        return group.WORLD

    def get_global_rank(group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import all_gather, get_world_size
        from torch import chunk
        output_tensors = list(chunk(output_tensor, get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)

    def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import reduce_scatter, get_world_size
        from torch import chunk
        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
        return reduce_scatter(output_tensor, input_tensor_lst, group=group)

    def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None):
        utils.logger.warn("Communication logging is not supported in torch versions older than 1.8")
else:
    supported_torch_version = True
    from .comm import *
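Both fallback helpers in the shim above (and their replacements in comm.py below) rely on torch.chunk returning views: gathering into the chunk list writes straight into the flat output tensor. A quick demonstration:

import torch

out = torch.zeros(8)
parts = list(torch.chunk(out, 4))  # views over `out`, not copies
parts[2].copy_(torch.ones(2))
print(out)  # tensor([0., 0., 0., 0., 1., 1., 0., 0.])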
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
DeepSpeed Communication Backend.

In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distributed. Simply wrap torch.distributed for now.
@@ -8,18 +12,18 @@ In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distribute
 - MPI -- [EXPERIMENTAL]
 - RCCL -- [EXPERIMENTAL]
 - GLOO -- [EXPERIMENTAL]
# DS backend wrapper for torch.distributed [DEFAULT]
 - T-NCCL -- [DEFAULT]
 - T-GLOO
 - T-MPI

# Backend is the base class
    -- NcclBackend, MpiBackend, and TorchBackend are the main subclasses. TorchBackend is the only officially supported backend for now.
"""


class Backend(object):

    def __init__(self, name='backend', rank=0, size=1):
        self.name = name
        # The world size and rank of the world process group
...
""" # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed Communication Package: deepspeed.comm DeepSpeed Communication Package: deepspeed.comm
deepspeed.comm deepspeed.comm
-- import and use deepspeeed.ops.comm -- import and use deepspeeed.ops.comm
-- use torch.distributed directly if both this package and torch.distributed use the same NCCL version -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
-- use custom collectives -- use custom collectives
-- can either use torch.dist or ds.ops.comm? -- can either use torch.dist or ds.ops.comm?
Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well. Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
deepspeed.comm API deepspeed.comm API
-- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility. -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
-- e.g. if a client code used -- e.g. if a client code used
from deepspeed import comm as dist from deepspeed import comm as dist
instead of instead of
import torch.distributed as dist import torch.distributed as dist
The code should work without breaking any of the public torch.distributed functionality The code should work without breaking any of the public torch.distributed functionality
Future: Future:
-- deepspeed groups API should be brought into ds.comm -- deepspeed groups API should be brought into ds.comm
""" """
@@ -63,10 +59,6 @@ timer_summary = {}

comms_logger = CommsLogger()

# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
@@ -110,12 +102,13 @@ def configure(

# Logging wrapper for timing ops
def timed_op(func):

    def log_wrapper(*args, **kwargs):
        # Add enabled flag so that overhead to each comm op is two if conditions at most
        if comms_logger.enabled:
            if ('prof' in kwargs
                    and kwargs['prof']) or comms_logger.prof_all or ('log_name' in kwargs
                                                                     and kwargs['log_name'] in comms_logger.prof_ops):
                # Need func args for their defaults
                func_args = get_default_args(func)
                func_args.update(kwargs)
@@ -133,8 +126,7 @@ def timed_op(func):
            if cdb.using_mpi:
                cdb.barrier()
            if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
                    'log_name' in kwargs and kwargs['log_name'] in comms_logger.prof_ops):
                log_name = get_debug_log_name(func_args, comms_logger.debug)
                raw_name = func.__name__
                timers(log_name).stop()
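For readers unfamiliar with the pattern: timed_op wraps every collective so that timing costs only a branch unless profiling is requested. A stripped-down sketch of the same idea (not the DeepSpeed implementation):

import functools
import time

def timed_op(func):

    @functools.wraps(func)
    def log_wrapper(*args, prof=False, **kwargs):
        start = time.perf_counter() if prof else None  # fast path: one branch
        result = func(*args, **kwargs)
        if prof:
            print(f'{func.__name__}: {time.perf_counter() - start:.6f}s')
        return result

    return log_wrapper

@timed_op
def broadcast(tensor, src):  # stand-in op for illustration
    return tensor

broadcast([1, 2, 3], 0, prof=True)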
@@ -182,7 +174,8 @@ def destroy_process_group(group=None):

def new_group(ranks):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.new_group(ranks)
@@ -196,14 +189,12 @@ def is_available() -> bool:
    return True


def set_backend(backend_name):
    if not use_ds_backend:
        utils.logger.error(
            "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
        )
        raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')

    global cdb
    global nccl_backend
@@ -221,13 +212,7 @@ def set_backend(backend):


@timed_op
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
    global cdb
    return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
@@ -241,17 +226,14 @@ def all_gather(tensor_list,
               log_name='all_gather',
               debug=get_caller_func()):
    global cdb
    return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)


def has_reduce_scatter_tensor():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_reduce_scatter_tensor()
def reduce_scatter_fn(output_tensor,
@@ -262,23 +244,21 @@ def reduce_scatter_fn(output_tensor,
                      prof=False,
                      debug=get_caller_func()):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb.has_reduce_scatter_tensor():
        return reduce_scatter_tensor(output_tensor,
                                     tensor,
                                     op=op,
                                     group=group,
                                     async_op=async_op,
                                     prof=prof,
                                     debug=debug)
    else:
        if get_rank() == 0:
            utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
                                      "torch.distributed.all_gather which will result in suboptimal performance. "
                                      "please consider upgrading your pytorch installation.")
        input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
        return reduce_scatter(output_tensor,
                              input_tensor_lst,
@@ -290,71 +270,54 @@ def reduce_scatter_fn(output_tensor,

@timed_op
def reduce_scatter_tensor(output_tensor,
                          tensor,
                          op=ReduceOp.SUM,
                          group=None,
                          async_op=False,
                          prof=False,
                          log_name='reduce_scatter_tensor',
                          debug=get_caller_func()):
    global cdb
    return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
                                     input_tensor=tensor,
                                     op=op,
                                     group=group,
                                     async_op=async_op)
@timed_op
def all_gather_into_tensor(output_tensor,
                           tensor,
                           group=None,
                           async_op=False,
                           prof=False,
                           log_name='all_gather_into_tensor',
                           debug=get_caller_func()):
    global cdb
    return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)


def has_all_gather_into_tensor():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_all_gather_into_tensor()


def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb.has_all_gather_into_tensor():
        return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
    else:
        if get_rank() == 0:
            utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
                                      "torch.distributed.all_gather which will result in suboptimal performance. "
                                      "please consider upgrading your pytorch installation.")
        output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)
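The has_* probes defer to the active backend; for the torch backend they effectively reduce to a capability check like the following (an assumption for illustration, not code from this diff):

import torch.distributed as dist

print(hasattr(dist, 'all_gather_into_tensor'))  # True on torch >= 1.13
print(hasattr(dist, 'reduce_scatter_tensor'))   # the probe gating the fallback path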
@timed_op
@@ -377,49 +340,25 @@ def all_to_all_single(output,


@timed_op
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
    global cdb
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
    global cdb
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
    global cdb
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
    global cdb
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@@ -434,11 +373,7 @@ def gather(tensor,
           log_name='gather',
           debug=get_caller_func()):
    global cdb
    return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)
@timed_op
@@ -451,20 +386,11 @@ def scatter(tensor,
            log_name='scatter',
            debug=get_caller_func()):
    global cdb
    return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)
@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
    global cdb
    return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@@ -511,11 +437,31 @@ def reduce_scatter(output,
                   log_name='reduce_scatter',
                   debug=get_caller_func()):
    global cdb
    return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)


def has_all_reduce_coalesced():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
    return cdb.has_all_reduce_coalesced


def has_coalescing_manager():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
    return cdb.has_coalescing_manager


def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)
@timed_op
@@ -535,9 +481,22 @@ def all_reduce(tensor,
    return cdb.all_reduce(tensor, op, group, async_op)


@timed_op
def all_reduce_coalesced(tensors,
                         op=ReduceOp.SUM,
                         group=None,
                         async_op=False,
                         prof=False,
                         log_name='all_reduce',
                         debug=get_caller_func()):
    global cdb
    return cdb.all_reduce_coalesced(tensors, op, group, async_op)
def get_world_group():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_world_group()
@@ -553,7 +512,8 @@ def get_world_size(group=None) -> int:
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_world_size(group)
@@ -572,7 +532,8 @@ def get_rank(group=None):
        -1, if not part of the group
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_rank(group)
@@ -585,13 +546,15 @@ def get_local_rank():
        local rank (= GPU device ID)
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return get_local_rank_from_launcher()


def get_global_rank(group=None, group_rank=0):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_global_rank(group, group_rank)
@@ -640,9 +603,7 @@ def init_distributed(dist_backend=None,
    required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
        if verbose:
            utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
        if in_aml() and not in_dlts():
            patch_aml_env_for_torch_nccl_backend(verbose=verbose)
        elif in_aws_sm():
@@ -658,9 +619,7 @@ def init_distributed(dist_backend=None,
        if dist_backend == None:
            dist_backend = get_accelerator().communication_backend_name()
        if int(os.getenv('RANK', '0')) == 0:
            utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend))
        # Create a torch backend object, initialize torch distributed, and assign to cdb
        cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
...@@ -695,16 +654,12 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True) ...@@ -695,16 +654,12 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
if verbose: if verbose:
utils.logger.info( utils.logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
.format(os.environ['RANK'], format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['LOCAL_RANK'], os.environ['MASTER_PORT']))
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if cdb is not None and cdb.is_initialized(): if cdb is not None and cdb.is_initialized():
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format( assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, cdb.get_rank())
rank, cdb.get_rank())
assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, cdb.get_world_size()) world_size, cdb.get_world_size())
...@@ -731,8 +686,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True): ...@@ -731,8 +686,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
""" """
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(os.environ["WORLD_SIZE"])
os.environ["WORLD_SIZE"])
if not single_node: if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
...@@ -745,8 +699,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True): ...@@ -745,8 +699,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
if verbose: if verbose:
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format( utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
...@@ -754,10 +707,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True): ...@@ -754,10 +707,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
if verbose: if verbose:
utils.logger.info( utils.logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'], .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT'])) os.environ['MASTER_PORT']))
...@@ -771,8 +721,5 @@ def patch_aws_sm_env_for_torch_nccl_backend(verbose=True): ...@@ -771,8 +721,5 @@ def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
if verbose: if verbose:
utils.logger.info( utils.logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'], .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT'])) os.environ['MASTER_PORT']))
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
""" # SPDX-License-Identifier: Apache-2.0
Copyright (c) Microsoft Corporation
Licensed under the MIT license. # DeepSpeed Team
"""
from pydantic import BaseModel from pydantic import BaseModel
from .constants import * from .constants import *
class CommsConfig(BaseModel): class CommsConfig(BaseModel):
class Config: class Config:
validate_all = True validate_all = True
validate_assignment = True validate_assignment = True
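# Hedged sketch of the ds_config section these models parse; the exact field
# names live in .constants and may differ from this illustration.
ds_config = {
    "comms_logger": {
        "enabled": True,   # turn the communication logger on
        "verbose": False,  # print every op as it is issued
        "prof_all": True,  # profile all communication ops
        "debug": False,
    }
}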
...@@ -25,6 +25,7 @@ class CommsLoggerConfig(CommsConfig): ...@@ -25,6 +25,7 @@ class CommsLoggerConfig(CommsConfig):
class DeepSpeedCommsConfig: class DeepSpeedCommsConfig:
def __init__(self, ds_config): def __init__(self, ds_config):
self.comms_logger_enabled = 'comms_logger' in ds_config self.comms_logger_enabled = 'comms_logger' in ds_config
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
NCCL_BACKEND = 'nccl' NCCL_BACKEND = 'nccl'
MPI_BACKEND = 'mpi' MPI_BACKEND = 'mpi'
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
from deepspeed import utils from deepspeed import utils
from .utils import * from .utils import *
from .backend import * from .backend import *
from .comm import * from .comm import *
import torch
def is_torch_two():
TORCH_MAJOR = int(torch.__version__.split('.')[0])
if TORCH_MAJOR >= 2:
return True
else:
return False
def has_coalescing_manager():
has_c10d = hasattr(torch.distributed, 'distributed_c10d')
return has_c10d and hasattr(torch.distributed.distributed_c10d, '_coalescing_manager')
def has_all_reduce_coalesced():
return hasattr(torch.distributed, "all_reduce_coalesced")
def get_coalescing_manager(group, device, reqs):
if is_torch_two():
return torch.distributed.distributed_c10d._coalescing_manager(group, device=device, reqs=reqs)
else:
return torch.distributed.distributed_c10d._coalescing_manager(group, reqs)
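# Hedged sketch of how these feature probes are meant to be used: pick the
# best collective available in the installed torch and fall back otherwise.
if has_all_reduce_coalesced():
    fused_allreduce = torch.distributed.all_reduce_coalesced  # torch >= 1.13
else:
    fused_allreduce = None  # callers loop over torch.distributed.all_reduce instead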
class TorchBackend(Backend): class TorchBackend(Backend):
""" """
...@@ -16,17 +44,13 @@ class TorchBackend(Backend): ...@@ -16,17 +44,13 @@ class TorchBackend(Backend):
so no need to wrap all the functions. We can keep adding wrappers as so no need to wrap all the functions. We can keep adding wrappers as
needed. needed.
""" """
def __init__(self,
backend, def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
timeout,
init_method,
rank=-1,
world_size=-1,
name='torch'):
super(TorchBackend, self).__init__() super(TorchBackend, self).__init__()
self.torch_version_before_18 = older_torch() self.has_all_reduce_coalesced = has_all_reduce_coalesced()
self.has_allgather_base = has_allgather_base() self.has_coalescing_manager = has_coalescing_manager()
self.has_reduce_scatter_base = has_reduce_scatter_base() self.all_gather_function = self.get_all_gather_function()
self.reduce_scatter_function = self.get_reduce_scatter_function()
self.initialized = True self.initialized = True
self.name = name self.name = name
# Future functionality to support ds.initialize() on a single GPU # Future functionality to support ds.initialize() on a single GPU
...@@ -35,6 +59,28 @@ class TorchBackend(Backend): ...@@ -35,6 +59,28 @@ class TorchBackend(Backend):
self.single_gpu_mode = True self.single_gpu_mode = True
self.init_process_group(backend, timeout, init_method, rank, world_size) self.init_process_group(backend, timeout, init_method, rank, world_size)
@classmethod
def get_all_gather_function(self):
if hasattr(torch.distributed, "all_gather_into_tensor"):
return torch.distributed.all_gather_into_tensor
elif hasattr(torch.distributed, "_all_gather_base"):
return torch.distributed._all_gather_base
return None
@classmethod
def get_reduce_scatter_function(self):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
return torch.distributed.reduce_scatter_tensor
elif hasattr(torch.distributed, "_reduce_scatter_base"):
return torch.distributed._reduce_scatter_base
return None
def has_all_gather_into_tensor(self):
return self.all_gather_function is not None
def has_reduce_scatter_tensor(self):
return self.reduce_scatter_function is not None
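# Hedged sketch (hypothetical helper, not part of this file): engine code can
# branch on the probes above to prefer the fused tensor collectives.
def gather_params(backend, flat_out, flat_in, shard_list):
    if backend.has_all_gather_into_tensor():
        return backend.all_gather_into_tensor(flat_out, flat_in)
    return backend.all_gather(shard_list, flat_in)  # per-tensor fallback path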
def init_process_group(self, backend, timeout, init_method, rank, world_size): def init_process_group(self, backend, timeout, init_method, rank, world_size):
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend, torch.distributed.init_process_group(backend,
...@@ -44,30 +90,24 @@ class TorchBackend(Backend): ...@@ -44,30 +90,24 @@ class TorchBackend(Backend):
world_size=world_size) world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi' self.using_mpi = torch.distributed.get_backend() == 'mpi'
def all_reduce(self, def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
tensor,
op=torch.distributed.ReduceOp.SUM,
group=None,
async_op=False):
op = self._reduce_op(op) op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
op=op,
group=group, def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
async_op=async_op) """ proxy func to torch.distributed.all_reduce_coalesced,
which is included in PyTorch 1.13 and above
"""
if not self.has_all_reduce_coalesced:
raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
f"api (torch.__version__: {torch.__version__})")
op = self._reduce_op(op)
return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
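# Hedged usage sketch (requires torch >= 1.13; 'backend' is assumed to be an
# initialized TorchBackend and tensors must live on the communication device):
grads = [torch.ones(4), torch.ones(8)]  # e.g. two gradient shards
backend.all_reduce_coalesced(grads)     # one fused launch instead of two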
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor, return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
dst=dst,
op=self._reduce_op(op),
group=group,
async_op=async_op)
def reduce_scatter(self, def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
output,
input_list,
op=ReduceOp.SUM,
group=None,
async_op=False):
return torch.distributed.reduce_scatter(output=output, return torch.distributed.reduce_scatter(output=output,
input_list=input_list, input_list=input_list,
op=self._reduce_op(op), op=self._reduce_op(op),
...@@ -75,48 +115,57 @@ class TorchBackend(Backend): ...@@ -75,48 +115,57 @@ class TorchBackend(Backend):
async_op=async_op) async_op=async_op)
def broadcast(self, tensor, src, group=None, async_op=False): def broadcast(self, tensor, src, group=None, async_op=False):
return torch.distributed.broadcast(tensor=tensor, return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
src=src,
group=group,
async_op=async_op)
def all_gather(self, tensor_list, tensor, group=None, async_op=False): def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return torch.distributed.all_gather(tensor_list=tensor_list, return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
tensor=tensor,
def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_all_gather_into_tensor():
return self.all_gather_function(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group, group=group,
async_op=async_op) async_op=async_op)
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_allgather_base:
return torch.distributed.distributed_c10d._all_gather_base(
output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
else: else:
utils.logger.warning( utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"unable to find torch.distributed._all_gather_base. will fall back to " "torch.distributed.all_gather which will result in suboptimal performance. "
"torch.distributed.reduce_scatter which will result in suboptimal performance. " "please consider upgrading your pytorch installation.")
"please consider upgrading your pytorch installation.")
pass pass
def reduce_scatter_base(self, def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
output_tensor, """All-gather a list of tensors with one coalesced launch."""
input_tensor, assert len(output_tensors) == len(input_tensors), "output_tensors and input_tensors must have the same length"
op=ReduceOp.SUM, if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
group=None, # customized PyTorch
async_op=False): return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
if self.has_reduce_scatter_base: input_tensors,
return torch.distributed._reduce_scatter_base(output_tensor, group=group,
input_tensor, async_op=async_op)
op=self._reduce_op(op), elif has_coalescing_manager():
group=group, reqs = []
async_op=async_op) with get_coalescing_manager(group, input_tensors[0].device, reqs):
for output, input in zip(output_tensors, input_tensors):
handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
input,
group=group,
async_op=True)
reqs.append(handle)
if async_op:
return reqs[-1]
else:
reqs[-1].wait()
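# Note the design choice above: with async_op=True only the last handle is
# returned, because the coalescing manager flushes the whole batch together.
# Hedged usage sketch ('backend' and 'world_size' assumed set up elsewhere):
ins = [torch.randn(4), torch.randn(8)]
outs = [torch.empty(4 * world_size), torch.empty(8 * world_size)]
backend.all_gather_coalesced(outs, ins)  # blocking form waits on the batch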
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
return self.reduce_scatter_function(output_tensor,
input_tensor,
op=self._reduce_op(op),
group=group,
async_op=async_op)
else: else:
utils.logger.warning( utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"unable to find torch.distributed._reduce_scatter_base. will fall back to " "torch.distributed.reduce_scatter which will result in suboptimal performance. "
"torch.distributed.reduce_scatter which will result in suboptimal performance. " "please consider upgrading your pytorch installation.")
"please consider upgrading your pytorch installation.")
pass pass
def all_to_all_single(self, def all_to_all_single(self,
...@@ -159,25 +208,15 @@ class TorchBackend(Backend): ...@@ -159,25 +208,15 @@ class TorchBackend(Backend):
group=group, group=group,
async_op=async_op) async_op=async_op)
def barrier(self, def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
group=torch.distributed.GroupMember.WORLD,
async_op=False,
device_ids=None):
if group is None: if group is None:
group = torch.distributed.GroupMember.WORLD group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group, return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
async_op=async_op,
device_ids=device_ids) def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
def monitored_barrier(self,
group=torch.distributed.GroupMember.WORLD,
timeout=None,
wait_all_ranks=False):
if group is None: if group is None:
group = torch.distributed.GroupMember.WORLD group = torch.distributed.GroupMember.WORLD
return torch.distributed.monitored_barrier(group=group, return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
timeout=timeout,
wait_all_ranks=wait_all_ranks)
def get_rank(self, group=None): def get_rank(self, group=None):
return torch.distributed.get_rank(group=group) return torch.distributed.get_rank(group=group)
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os import os
import torch
import inspect import inspect
from deepspeed.utils import get_caller_func from deepspeed.utils import get_caller_func
def older_torch():
'''
Helper to lookup torch version. For versions less than 1.8, torch.dist
used torch.distributed.group.WORLD as the default group argument instead of None.
See more details at: https://github.com/pytorch/pytorch/pull/48767
'''
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
return True
else:
return False
def has_allgather_base():
'''
Helper to check if torch.distributed has _all_gather_base
'''
return hasattr(torch.distributed, "_all_gather_base")
def has_reduce_scatter_base():
'''
Helper to check if torch.distributed has _reduce_scatter_base
'''
return hasattr(torch.distributed, "_reduce_scatter_base")
def get_local_rank_from_launcher(): def get_local_rank_from_launcher():
# DeepSpeed launcher will set it so get from there # DeepSpeed launcher will set it so get from there
...@@ -84,11 +58,7 @@ def get_world_size_from_launcher(): ...@@ -84,11 +58,7 @@ def get_world_size_from_launcher():
def get_default_args(func): def get_default_args(func):
signature = inspect.signature(func) signature = inspect.signature(func)
return { return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
k: v.default
for k,
v in signature.parameters.items() if v.default is not inspect.Parameter.empty
}
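# Quick illustration of get_default_args, which the comms logger uses to
# recover keyword arguments a caller omitted:
def _demo(x, y=3, z='a'):
    return x

assert get_default_args(_demo) == {'y': 3, 'z': 'a'}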
# We need this hacky function since torch doesn't consistently name or place the input tensor args # We need this hacky function since torch doesn't consistently name or place the input tensor args
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .compress import init_compression, redundancy_clean from .compress import init_compression, redundancy_clean
from .scheduler import compression_scheduler from .scheduler import compression_scheduler
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
import math import math
...@@ -21,6 +24,7 @@ class QuantAct(nn.Module): ...@@ -21,6 +24,7 @@ class QuantAct(nn.Module):
Momentum for updating the activation quantization range. Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric' quant_mode : str, default 'symmetric'
""" """
def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'): def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
super(QuantAct, self).__init__() super(QuantAct, self).__init__()
...@@ -50,10 +54,8 @@ class QuantAct(nn.Module): ...@@ -50,10 +54,8 @@ class QuantAct(nn.Module):
self.x_min_max[1] = x_max self.x_min_max[1] = x_max
# if momentum is not needed, set self.act_range_momentum = 0 # if momentum is not needed, set self.act_range_momentum = 0
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * ( self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
1 - self.act_range_momentum) self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (
1 - self.act_range_momentum)
x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1]) x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
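# Worked example of the running-range update above with
# act_range_momentum = 0.95: the tracked minimum moves only 5% of the way
# toward the current batch minimum each step.
old_min, batch_min, m = -1.0, -2.0, 0.95
new_min = m * old_min + (1 - m) * batch_min  # -> -1.05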
...@@ -61,6 +63,7 @@ class QuantAct(nn.Module): ...@@ -61,6 +63,7 @@ class QuantAct(nn.Module):
class Embedding_Compress(nn.Embedding): class Embedding_Compress(nn.Embedding):
def __init__(self, *kargs): def __init__(self, *kargs):
super(Embedding_Compress, self).__init__(*kargs) super(Embedding_Compress, self).__init__(*kargs)
self.weight.start_bits = None self.weight.start_bits = None
...@@ -71,17 +74,10 @@ class Embedding_Compress(nn.Embedding): ...@@ -71,17 +74,10 @@ class Embedding_Compress(nn.Embedding):
def extra_repr(self): def extra_repr(self):
return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format( return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
self.num_embeddings, self.num_embeddings, self.embedding_dim, self.weight.target_bits)
self.embedding_dim,
self.weight.target_bits) def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits self.weight.start_bits = start_bits
self.weight.target_bits = target_bits self.weight.target_bits = target_bits
self.weight.q_period = quantization_period self.weight.q_period = quantization_period
...@@ -105,31 +101,20 @@ class Embedding_Compress(nn.Embedding): ...@@ -105,31 +101,20 @@ class Embedding_Compress(nn.Embedding):
self.weight_quantize_num_groups = self.weight.size(0) self.weight_quantize_num_groups = self.weight.size(0)
def fix_weight_quantization(self): def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False self.weight_quantization_enabled_in_forward = False
return None return None
def forward(self, input): def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled: if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups) self.weight_quantize_num_groups)
else: else:
weight = self.weight weight = self.weight
out = nn.functional.embedding(input, out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
weight, self.scale_grad_by_freq, self.sparse)
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse)
return out return out
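# Hedged usage sketch of the quantized embedding (argument values are
# illustrative only, not recommended settings):
emb = Embedding_Compress(1000, 64)
emb.enable_weight_quantization(start_bits=12, target_bits=8,
                               quantization_period=100,
                               weight_quantization_enabled_in_forward=True,
                               quantization_type='symmetric',
                               num_groups=1)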
...@@ -137,6 +122,7 @@ class LinearLayer_Compress(nn.Linear): ...@@ -137,6 +122,7 @@ class LinearLayer_Compress(nn.Linear):
""" """
Linear layer with compression. Linear layer with compression.
""" """
def __init__(self, *kargs, bias=True): def __init__(self, *kargs, bias=True):
super(LinearLayer_Compress, self).__init__(*kargs, bias=bias) super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
self.sparse_pruning_method = None self.sparse_pruning_method = None
...@@ -169,8 +155,7 @@ class LinearLayer_Compress(nn.Linear): ...@@ -169,8 +155,7 @@ class LinearLayer_Compress(nn.Linear):
mask = mask.to(self.weight.device) mask = mask.to(self.weight.device)
elif method == 'topk': elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size())) self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to( self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5)) init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None mask = None
else: else:
...@@ -209,11 +194,9 @@ class LinearLayer_Compress(nn.Linear): ...@@ -209,11 +194,9 @@ class LinearLayer_Compress(nn.Linear):
raise NotImplementedError raise NotImplementedError
else: else:
self.head_pruning_ratio = ratio self.head_pruning_ratio = ratio
self.head_pruning_scores = nn.Parameter(torch.Tensor( self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
1, self.num_heads)) # we apply the pruning to O matrix
self.num_heads)) # we apply the pruning to O matrix self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
self.head_pruning_scores.data = self.head_pruning_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5)) init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
def fix_sparse_pruning_helper(self): def fix_sparse_pruning_helper(self):
...@@ -279,18 +262,17 @@ class LinearLayer_Compress(nn.Linear): ...@@ -279,18 +262,17 @@ class LinearLayer_Compress(nn.Linear):
start_bits = self.weight.start_bits start_bits = self.weight.start_bits
target_bits = self.weight.target_bits target_bits = self.weight.target_bits
q_period = self.weight.q_period q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t()) self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads,
-1)[mask.view(-1), :].reshape(-1,
shape).t())
self.weight.start_bits = start_bits self.weight.start_bits = start_bits
self.weight.target_bits = target_bits self.weight.target_bits = target_bits
self.weight.q_period = q_period self.weight.q_period = q_period
else: else:
shape = self.weight.size() shape = self.weight.size()
self.weight.data = (self.weight.data.t().reshape(self.num_heads, self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
-1) * shape[1], shape[0]).t()
mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
if self.head_pruning_method == 'topk': if self.head_pruning_method == 'topk':
del self.head_pruning_scores del self.head_pruning_scores
...@@ -316,37 +298,26 @@ class LinearLayer_Compress(nn.Linear): ...@@ -316,37 +298,26 @@ class LinearLayer_Compress(nn.Linear):
if self.sparse_pruning_method == 'l1': if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device) return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk': elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
self.sparse_pruning_ratio,
False)
else: else:
raise NotImplementedError raise NotImplementedError
if pruning_type == 'row': if pruning_type == 'row':
if self.row_pruning_method == 'l1': if self.row_pruning_method == 'l1':
return self.row_pruning_mask.to(self.weight.device) return self.row_pruning_mask.to(self.weight.device)
elif self.row_pruning_method == 'topk': elif self.row_pruning_method == 'topk':
return TopKBinarizer.apply(self.row_mask_scores, return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
self.row_pruning_ratio,
False)
else: else:
raise NotImplementedError raise NotImplementedError
elif pruning_type == 'head': elif pruning_type == 'head':
if self.head_pruning_method == 'topk': if self.head_pruning_method == 'topk':
return TopKBinarizer.apply(self.head_pruning_scores, return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
self.head_pruning_ratio,
False)
else: else:
raise NotImplementedError raise NotImplementedError
else: else:
raise NotImplementedError raise NotImplementedError
def enable_weight_quantization(self, def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
start_bits, weight_quantization_enabled_in_forward, quantization_type, num_groups):
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits self.weight.start_bits = start_bits
self.weight.target_bits = target_bits self.weight.target_bits = target_bits
self.weight.q_period = quantization_period self.weight.q_period = quantization_period
...@@ -369,10 +340,7 @@ class LinearLayer_Compress(nn.Linear): ...@@ -369,10 +340,7 @@ class LinearLayer_Compress(nn.Linear):
self.weight_quantize_num_groups = num_groups self.weight_quantize_num_groups = num_groups
def fix_weight_quantization(self): def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False self.weight_quantization_enabled_in_forward = False
return None return None
...@@ -391,18 +359,12 @@ class LinearLayer_Compress(nn.Linear): ...@@ -391,18 +359,12 @@ class LinearLayer_Compress(nn.Linear):
def head_pruning_reshape(self, w, mask): def head_pruning_reshape(self, w, mask):
shape = w.shape shape = w.shape
return (w.t().reshape(self.num_heads, return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
-1) * mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
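# Hedged shape walk-through of the reshape above: the transposed weight's rows
# are grouped into num_heads contiguous blocks, one mask entry per head.
import torch
w = torch.arange(24.).reshape(6, 4)  # [out=6, in=4], num_heads = 2
mask = torch.tensor([1., 0.])        # keep head 0, zero out head 1
pruned = (w.t().reshape(2, -1) * mask.view(-1, 1)).reshape(4, 6).t()
assert pruned.shape == w.shape       # same shape, one head's block zeroed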
def forward(self, input, skip_bias_add=False): def forward(self, input, skip_bias_add=False):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled: if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups) self.weight_quantize_num_groups)
bias = self.bias bias = self.bias
else: else:
...@@ -428,11 +390,7 @@ class LinearLayer_Compress(nn.Linear): ...@@ -428,11 +390,7 @@ class LinearLayer_Compress(nn.Linear):
num_groups = input.numel() // input.size(-1) num_groups = input.numel() // input.size(-1)
else: else:
num_groups = 1 num_groups = 1
input = self.activation_quantizer(input, input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
self.activation_quantization_bits,
None,
None,
num_groups)
if skip_bias_add: if skip_bias_add:
# used for mpu linear layers # used for mpu linear layers
...@@ -447,6 +405,7 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -447,6 +405,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
""" """
Conv2D layer with compression. Conv2D layer with compression.
""" """
def __init__(self, *kargs): def __init__(self, *kargs):
super(Conv2dLayer_Compress, self).__init__(*kargs) super(Conv2dLayer_Compress, self).__init__(*kargs)
self.sparse_pruning_method = None self.sparse_pruning_method = None
...@@ -478,10 +437,8 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -478,10 +437,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
output = s.format(**self.__dict__) output = s.format(**self.__dict__)
return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format( return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
self.sparse_pruning_method is not None, self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
self.channel_pruning_method is not None, self.activation_quantization_method is not None, self.weight.target_bits)
self.activation_quantization_method is not None,
self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method): def enable_sparse_pruning(self, ratio, method):
self.sparse_pruning_ratio = ratio self.sparse_pruning_ratio = ratio
...@@ -493,8 +450,7 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -493,8 +450,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.to(self.weight.device) mask = mask.to(self.weight.device)
elif method == 'topk': elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size())) self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to( self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5)) init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None mask = None
else: else:
...@@ -514,13 +470,8 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -514,13 +470,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.view(-1, 1, 1, 1) mask = mask.view(-1, 1, 1, 1)
mask = mask.to(self.weight.device) mask = mask.to(self.weight.device)
elif method == 'topk': elif method == 'topk':
self.channel_mask_scores = nn.Parameter( self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
torch.Tensor(self.weight.size(0), self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
1,
1,
1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(
self.weight.device)
init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5)) init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
mask = None mask = None
else: else:
...@@ -579,39 +530,27 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -579,39 +530,27 @@ class Conv2dLayer_Compress(nn.Conv2d):
if self.sparse_pruning_method == 'l1': if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device) return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk': elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
self.sparse_pruning_ratio,
False)
else: else:
raise NotImplementedError raise NotImplementedError
elif pruning_type == 'channel': elif pruning_type == 'channel':
if self.channel_pruning_method == 'l1': if self.channel_pruning_method == 'l1':
return self.channel_pruning_mask.to(self.weight.device) return self.channel_pruning_mask.to(self.weight.device)
elif self.channel_pruning_method == 'topk': elif self.channel_pruning_method == 'topk':
return TopKBinarizer.apply(self.channel_mask_scores, return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
self.channel_pruning_ratio,
False)
else: else:
raise NotImplementedError raise NotImplementedError
else: else:
raise NotImplementedError raise NotImplementedError
def fix_weight_quantization(self): def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups).data self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False self.weight_quantization_enabled_in_forward = False
return None return None
def enable_weight_quantization(self, def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
start_bits, weight_quantization_enabled_in_forward, quantization_type, num_groups):
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
self.weight.start_bits = start_bits self.weight.start_bits = start_bits
self.weight.target_bits = target_bits self.weight.target_bits = target_bits
self.weight.q_period = quantization_period self.weight.q_period = quantization_period
...@@ -642,10 +581,7 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -642,10 +581,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
def forward(self, input): def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled: if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight.target_bits,
None,
None,
self.weight_quantize_num_groups) self.weight_quantize_num_groups)
bias = self.bias bias = self.bias
else: else:
...@@ -667,22 +603,13 @@ class Conv2dLayer_Compress(nn.Conv2d): ...@@ -667,22 +603,13 @@ class Conv2dLayer_Compress(nn.Conv2d):
num_groups = input.numel() // input[0].numel() num_groups = input.numel() // input[0].numel()
else: else:
num_groups = 1 num_groups = 1
input = self.activation_quantizer(input, input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
self.activation_quantization_bits,
None,
None,
num_groups)
return nn.functional.conv2d(input, return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups)
class BNLayer_Compress(nn.BatchNorm2d): class BNLayer_Compress(nn.BatchNorm2d):
def fix_channel_pruning_helper(self, mask, dim_reduction=True): def fix_channel_pruning_helper(self, mask, dim_reduction=True):
self.weight = nn.Parameter(self.weight.data[mask.view(-1)]) self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
self.bias = nn.Parameter(self.bias.data[mask.view(-1)]) self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
...@@ -770,6 +697,7 @@ def _gather(input_): ...@@ -770,6 +697,7 @@ def _gather(input_):
class _CopyToModelParallelRegion(torch.autograd.Function): class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region.""" """Pass the input to the model parallel region."""
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return input_ return input_
...@@ -781,6 +709,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function): ...@@ -781,6 +709,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function): class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region.""" """All-redcue the input from the model parallel region."""
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return _reduce(input_) return _reduce(input_)
...@@ -792,6 +721,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function): ...@@ -792,6 +721,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function): class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank.""" """Split the input and keep only the corresponding chuck to the rank."""
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return _split(input_) return _split(input_)
...@@ -803,6 +733,7 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): ...@@ -803,6 +733,7 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function): class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate.""" """Gather the input from model parallel region and concatinate."""
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return _gather(input_) return _gather(input_)
...@@ -834,13 +765,8 @@ def gather_from_model_parallel_region(input_): ...@@ -834,13 +765,8 @@ def gather_from_model_parallel_region(input_):
class ColumnParallelLinear_Compress(LinearLayer_Compress): class ColumnParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu, def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
input_size,
output_size,
bias=True,
gather_output=True,
skip_bias_add=False):
# Keep input parameters # Keep input parameters
global g_mpu global g_mpu
g_mpu = mpu g_mpu = mpu
...@@ -854,10 +780,7 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress): ...@@ -854,10 +780,7 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
assert output_size % world_size == 0 assert output_size % world_size == 0
self.output_size_per_partition = output_size // world_size self.output_size_per_partition = output_size // world_size
super(ColumnParallelLinear_Compress, super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)
self).__init__(self.input_size,
self.output_size_per_partition,
bias=bias)
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce. # Set up backprop all-reduce.
...@@ -877,13 +800,8 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress): ...@@ -877,13 +800,8 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
class RowParallelLinear_Compress(LinearLayer_Compress): class RowParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu, def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
input_size,
output_size,
bias=True,
input_is_parallel=False,
skip_bias_add=False):
# Keep input parameters # Keep input parameters
global g_mpu global g_mpu
g_mpu = mpu g_mpu = mpu
...@@ -897,10 +815,7 @@ class RowParallelLinear_Compress(LinearLayer_Compress): ...@@ -897,10 +815,7 @@ class RowParallelLinear_Compress(LinearLayer_Compress):
assert input_size % world_size == 0 assert input_size % world_size == 0
self.input_size_per_partition = input_size // world_size self.input_size_per_partition = input_size // world_size
super(RowParallelLinear_Compress, super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)
self).__init__(self.input_size_per_partition,
self.output_size,
bias=bias)
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce. # Set up backprop all-reduce.
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import re import re
from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible from .helper import compression_preparation, fix_compression, recursive_getattr, is_module_compressible
...@@ -13,21 +16,13 @@ def check_deepspeed_config(config): ...@@ -13,21 +16,13 @@ def check_deepspeed_config(config):
if isinstance(config, dict): if isinstance(config, dict):
return config return config
elif os.path.exists(config): elif os.path.exists(config):
return json.load(open(config, return json.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else: else:
raise ValueError( raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}" f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}")
)
def get_module_name(group_name, def get_module_name(group_name, model, key_word, exist_module_name, mpu=None, verbose=True):
model,
key_word,
exist_module_name,
mpu=None,
verbose=True):
''' '''
get the associated module name from the model based on the key_word provided by users get the associated module name from the model based on the key_word provided by users
''' '''
...@@ -40,8 +35,7 @@ def get_module_name(group_name, ...@@ -40,8 +35,7 @@ def get_module_name(group_name,
if name in exist_module_name and verbose: if name in exist_module_name and verbose:
# logger.warning # logger.warning
raise ValueError( raise ValueError(
f"{name} is already added to compression, please check your config file for {group_name}." f"{name} is already added to compression, please check your config file for {group_name}.")
)
if name not in exist_module_name: if name not in exist_module_name:
exist_module_name.add(name) exist_module_name.add(name)
return_module_name.append(name) return_module_name.append(name)
...@@ -56,8 +50,7 @@ def get_compress_methods(model, compress_methods, mpu=None): ...@@ -56,8 +50,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
continue continue
# for loop different methods, i.e., weight quantization, activation quantization etc # for loop different methods, i.e., weight quantization, activation quantization etc
exist_module_name = set() exist_module_name = set()
shared_parameters = method_content[ shared_parameters = method_content[SHARED_PARAMETERS] # get all the shared parameters
SHARED_PARAMETERS] # get all the shared parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items(): for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
# for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc # for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc
module_name_list = [] module_name_list = []
...@@ -65,8 +58,13 @@ def get_compress_methods(model, compress_methods, mpu=None): ...@@ -65,8 +58,13 @@ def get_compress_methods(model, compress_methods, mpu=None):
if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]: if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
# this is used for head/row/channel pruning, if users provide the related module scope, we can shrink the layer dim for them # this is used for head/row/channel pruning, if users provide the related module scope, we can shrink the layer dim for them
# otherwise we just mask those as zeros # otherwise we just mask those as zeros
for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE], method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]): for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE],
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu) method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name) module_name_list.append(module_name)
tmp_related_module_name_list = [] tmp_related_module_name_list = []
for rkw in related_key_words: for rkw in related_key_words:
...@@ -76,7 +74,11 @@ def get_compress_methods(model, compress_methods, mpu=None): ...@@ -76,7 +74,11 @@ def get_compress_methods(model, compress_methods, mpu=None):
related_module_name_list.append(tmp_related_module_name_list) related_module_name_list.append(tmp_related_module_name_list)
else: else:
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]: for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu) module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name) module_name_list.append(module_name)
if module_name_list: if module_name_list:
...@@ -85,13 +87,7 @@ def get_compress_methods(model, compress_methods, mpu=None): ...@@ -85,13 +87,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
**(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)), **(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
**shared_parameters **shared_parameters
} }
compression_item = [ compression_item = [module_name_list, related_module_name_list, {method: combined_method_parameters}]
module_name_list,
related_module_name_list,
{
method: combined_method_parameters
}
]
layer_added_compress_methods.append(compression_item) layer_added_compress_methods.append(compression_item)
return layer_added_compress_methods return layer_added_compress_methods
...@@ -118,9 +114,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None): ...@@ -118,9 +114,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
assert teacher_model is not None, "Teacher model is required for layer reduction" assert teacher_model is not None, "Teacher model is required for layer reduction"
student_initialization(c_model, teacher_model, deepspeed_config) student_initialization(c_model, teacher_model, deepspeed_config)
layer_added_compress_methods = get_compress_methods(c_model, layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compress_methods,
mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu) compression_preparation(c_model, layer_added_compress_methods, mpu)
return model return model
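# Hedged end-to-end sketch of this module's two entry points (config path and
# model are placeholders):
from deepspeed.compression import init_compression, redundancy_clean

model = init_compression(model, 'ds_config.json')  # wrap layers per the config
# ... train / fine-tune with compression enabled ...
model = redundancy_clean(model, 'ds_config.json')  # fix masks and shrink dims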
...@@ -143,31 +137,20 @@ def redundancy_clean(model, deepspeed_config, mpu=None): ...@@ -143,31 +137,20 @@ def redundancy_clean(model, deepspeed_config, mpu=None):
else: else:
c_model = model c_model = model
layer_added_compress_methods_tmp = get_compress_methods(c_model, layer_added_compress_methods_tmp = get_compress_methods(c_model, compress_methods, mpu=mpu)
compress_methods,
mpu=mpu)
# sort methods # sort methods
order_list = [ order_list = [
WEIGHT_QUANTIZATION, WEIGHT_QUANTIZATION, SPARSE_PRUNING, ROW_PRUNING, HEAD_PRUNING, CHANNEL_PRUNING, ACTIVATION_QUANTIZATION
SPARSE_PRUNING,
ROW_PRUNING,
HEAD_PRUNING,
CHANNEL_PRUNING,
ACTIVATION_QUANTIZATION
] ]
layer_added_compress_methods = sorted( layer_added_compress_methods = sorted(layer_added_compress_methods_tmp,
layer_added_compress_methods_tmp, key=lambda x: order_list.index(list(x[2].keys())[0]))
key=lambda x: order_list.index(list(x[2].keys())[0]))
for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods: for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
stored_mask = [] stored_mask = []
need_mask = bool(related_module_name_lists) need_mask = bool(related_module_name_lists)
for i, mnl in enumerate(module_name_lists): for i, mnl in enumerate(module_name_lists):
for module_name in mnl: for module_name in mnl:
mask = fix_compression(c_model, mask = fix_compression(c_model, module_name, compression_technique, dim_reduction=need_mask)
module_name,
compression_technique,
dim_reduction=need_mask)
if need_mask: if need_mask:
stored_mask.append(mask) stored_mask.append(mask)
if need_mask: if need_mask:
...@@ -219,10 +202,8 @@ def student_initialization(student_model, teacher_model, deepspeed_config): ...@@ -219,10 +202,8 @@ def student_initialization(student_model, teacher_model, deepspeed_config):
''' '''
assert len(student_layer) == len(teacher_layer) assert len(student_layer) == len(teacher_layer)
for s_name, t_name in zip(student_layer, teacher_layer): for s_name, t_name in zip(student_layer, teacher_layer):
s_module = recursive_getattr(student_model, s_module = recursive_getattr(student_model, module_name_prefix + '.' + str(s_name))
module_name_prefix + '.' + str(s_name)) t_module = recursive_getattr(teacher_model, module_name_prefix + '.' + str(t_name))
t_module = recursive_getattr(teacher_model,
module_name_prefix + '.' + str(t_name))
for s_param, t_param in zip(s_module.parameters(), t_module.parameters()): for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
s_param.data.copy_(t_param.data) s_param.data.copy_(t_param.data)
for name in other_module_name: for name in other_module_name:
......