Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .constants import *
import copy
@@ -36,9 +39,7 @@ def get_layer_reduction(param_dict):
def get_layer_reduction_enabled(param_dict):
if LAYER_REDUCTION in param_dict.keys():
return get_scalar_param(param_dict[LAYER_REDUCTION],
LAYER_REDUCTION_ENABLED,
LAYER_REDUCTION_ENABLED_DEFAULT)
return get_scalar_param(param_dict[LAYER_REDUCTION], LAYER_REDUCTION_ENABLED, LAYER_REDUCTION_ENABLED_DEFAULT)
else:
return False
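For context, get_scalar_param comes from DeepSpeed's runtime config utilities; conceptually it is a dict lookup with a default. A minimal sketch of that behavior (an illustrative assumption, not the exact DeepSpeed implementation):

def get_scalar_param(param_dict, param_name, param_default_value):
    # Return the configured value if present, otherwise fall back to the default.
    return param_dict.get(param_name, param_default_value)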
@@ -70,7 +71,8 @@ def get_weight_quantization(param_dict):
output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
return output
@@ -79,51 +81,38 @@ def get_weight_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_KERNEL,
WEIGHT_QUANTIZE_KERNEL_DEFAULT)
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_GROUPS,
WEIGHT_QUANTIZE_GROUPS_DEFAULT)
output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_VERBOSE,
WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict,
WEIGHT_QUANTIZE_TYPE,
output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_KERNEL,
WEIGHT_QUANTIZE_KERNEL_DEFAULT)
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_GROUPS,
WEIGHT_QUANTIZE_GROUPS_DEFAULT)
output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_VERBOSE,
WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_TYPE,
WEIGHT_QUANTIZE_TYPE_DEFAULT)
output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ROUNDING,
WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(sub_param_dict,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
assert output[WEIGHT_QUANTIZE_TYPE] in [
WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC
], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ROUNDING,
WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
assert output[WEIGHT_QUANTIZE_ROUNDING] in [
WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING
], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_CHANGE_RATIO,
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_CHANGE_RATIO,
WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
else:
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
else:
output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
@@ -133,8 +122,7 @@ def get_weight_quantization_shared_parameters(param_dict):
output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
return output
@@ -144,27 +132,21 @@ def get_weight_quantization_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(
group_dict,
WEIGHT_QUANTIZATION_PERIOD,
WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(
), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(
), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(group_dict, WEIGHT_QUANTIZATION_PERIOD,
WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
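For reference, the param_dict consumed by these weight-quantization parsers corresponds to the "weight_quantization" block of a DeepSpeed compression config. A minimal sketch as a Python dict follows; the literal key names and values are assumptions based on the public compression documentation, with the constants above mapping onto them:

weight_quantization_config = {
    "shared_parameters": {
        "enabled": True,                      # WEIGHT_QUANTIZE_ENABLED
        "quantizer_kernel": False,            # WEIGHT_QUANTIZE_KERNEL
        "schedule_offset": 0,                 # WEIGHT_QUANTIZE_SCHEDULE_OFFSET
        "quantize_groups": 1,                 # WEIGHT_QUANTIZE_GROUPS
        "quantize_verbose": False,            # WEIGHT_QUANTIZE_VERBOSE
        "quantization_type": "symmetric",     # WEIGHT_QUANTIZE_TYPE
        "rounding": "nearest",                # WEIGHT_QUANTIZE_ROUNDING
        "quantize_weight_in_forward": False,  # WEIGHT_QUANTIZE_IN_FORWARD_ENABLED
        "fp16_mixed_quantize": {
            "enabled": False,
            "quantize_change_ratio": 0.001,
        },
    },
    "different_groups": {
        "wq1": {
            "params": {"start_bits": 8, "target_bits": 4, "quantization_period": 50},
            "modules": ["attention.self", "intermediate"],
        },
    },
}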
@@ -172,19 +154,15 @@ def get_weight_quantization_different_groups(param_dict):
def get_activation_quantization(param_dict):
output = {}
if ACTIVATION_QUANTIZATION not in param_dict.keys():
param_dict[ACTIVATION_QUANTIZATION] = {
SHARED_PARAMETERS: {},
DIFFERENT_GROUPS: {}
}
param_dict[ACTIVATION_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
# shared parameters
output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(
sub_param_dict)
output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(
sub_param_dict)
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(sub_param_dict)
return output
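The activation-quantization parser expects the same shared_parameters/different_groups shape. A hedged example (key names assumed from the compression documentation):

activation_quantization_config = {
    "shared_parameters": {
        "enabled": True,
        "quantization_type": "asymmetric",   # symmetric or asymmetric
        "range_calibration": "dynamic",      # dynamic or static
        "schedule_offset": 50,
    },
    "different_groups": {
        "aq1": {
            "params": {"bits": 8},
            "modules": ["attention.output"],
        },
    },
}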
@@ -192,30 +170,26 @@ def get_activation_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZATION_ENABLED,
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_TYPE,
ACTIVATION_QUANTIZE_TYPE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_RANGE,
ACTIVATION_QUANTIZE_RANGE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZATION_ENABLED,
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_TYPE,
ACTIVATION_QUANTIZE_TYPE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_TYPE] in [
ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC
], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_RANGE,
ACTIVATION_QUANTIZE_RANGE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_RANGE] in [
ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC
], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
else:
output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
output[
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
return output
@@ -224,22 +198,17 @@ def get_activation_quantization_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(
), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@@ -253,7 +222,8 @@ def get_sparse_pruning(param_dict):
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output
@@ -262,18 +232,15 @@ def get_sparse_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
SPARSE_PRUNING_METHOD,
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
@@ -286,22 +253,17 @@ def get_sparse_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
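The pruning parsers follow the same pattern, with each group carrying a dense ratio instead of bit widths. A hedged sketch of a sparse-pruning block (key names assumed):

sparse_pruning_config = {
    "shared_parameters": {
        "enabled": True,
        "schedule_offset": 30,
        "method": "l1",   # l1 or topk
    },
    "different_groups": {
        "sp1": {
            "params": {"dense_ratio": 0.5},
            "modules": ["attention.self"],
        },
    },
}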
@@ -315,7 +277,8 @@ def get_row_pruning(param_dict):
output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
return output
@@ -324,17 +287,14 @@ def get_row_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
ROW_PRUNING_ENABLED,
output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, ROW_PRUNING_ENABLED,
ROW_PRUNING_ENABLED_DEFAULT)
output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
ROW_PRUNING_METHOD,
ROW_PRUNING_METHOD_DEFAULT)
assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ROW_PRUNING_SCHEDULE_OFFSET,
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, ROW_PRUNING_METHOD, ROW_PRUNING_METHOD_DEFAULT)
assert output[ROW_PRUNING_METHOD] in [
ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK
], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ROW_PRUNING_SCHEDULE_OFFSET,
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
@@ -347,22 +307,17 @@ def get_row_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@@ -375,7 +330,8 @@ def get_head_pruning(param_dict):
output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
return output
@@ -384,19 +340,18 @@ def get_head_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_ENABLED,
output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, HEAD_PRUNING_ENABLED,
HEAD_PRUNING_ENABLED_DEFAULT)
output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_METHOD,
output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, HEAD_PRUNING_METHOD,
HEAD_PRUNING_METHOD_DEFAULT)
assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
HEAD_PRUNING_SCHEDULE_OFFSET,
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
assert output[HEAD_PRUNING_METHOD] in [
HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK
], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, HEAD_PRUNING_SCHEDULE_OFFSET,
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
if output[HEAD_PRUNING_ENABLED]:
assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(
), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
else:
output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
@@ -410,22 +365,17 @@ def get_head_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"dense_ratio must be specified for head pruning group {name}"
assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(
), f"dense_ratio must be specified for head pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@@ -438,7 +388,8 @@ def get_channel_pruning(param_dict):
output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
return output
@@ -447,19 +398,15 @@ def get_channel_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_ENABLED,
CHANNEL_PRUNING_ENABLED_DEFAULT)
output[CHANNEL_PRUNING_METHOD] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_METHOD,
CHANNEL_PRUNING_METHOD_DEFAULT)
assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_SCHEDULE_OFFSET,
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_ENABLED,
CHANNEL_PRUNING_ENABLED_DEFAULT)
output[CHANNEL_PRUNING_METHOD] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_METHOD,
CHANNEL_PRUNING_METHOD_DEFAULT)
assert output[CHANNEL_PRUNING_METHOD] in [
CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK
], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_SCHEDULE_OFFSET,
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
@@ -472,21 +419,16 @@ def get_channel_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
#########################################
# Compression Methods
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
@@ -59,23 +62,17 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
need_bias = True
# Initialize the new module
if isinstance(old_module,
LinearLayer_Compress) or isinstance(old_module,
torch.nn.Linear):
if isinstance(old_module, LinearLayer_Compress) or isinstance(old_module, torch.nn.Linear):
if isinstance(old_module, LinearLayer_Compress):
new_module = old_module
else:
new_module = LinearLayer_Compress(old_module.in_features,
old_module.out_features,
bias=need_bias).to(
device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = LinearLayer_Compress(old_module.in_features, old_module.out_features,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module,
Conv2dLayer_Compress) or isinstance(old_module,
torch.nn.Conv2d):
elif isinstance(old_module, Conv2dLayer_Compress) or isinstance(old_module, torch.nn.Conv2d):
if isinstance(old_module, Conv2dLayer_Compress):
new_module = old_module
else:
@@ -86,60 +83,48 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module, torch.nn.BatchNorm2d):
new_module = BNLayer_Compress(old_module.num_features,
old_module.eps,
old_module.momentum,
old_module.affine,
old_module.track_running_stats).to(
old_module.weight.device,
old_module.weight.dtype)
new_module = BNLayer_Compress(old_module.num_features, old_module.eps, old_module.momentum, old_module.affine,
old_module.track_running_stats).to(old_module.weight.device,
old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
new_module.running_mean.data = old_module.running_mean.data
new_module.running_var.data = old_module.running_var.data
elif isinstance(old_module,
Embedding_Compress) or isinstance(old_module,
torch.nn.Embedding):
elif isinstance(old_module, Embedding_Compress) or isinstance(old_module, torch.nn.Embedding):
if isinstance(old_module, Embedding_Compress):
new_module = old_module
else:
new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
elif mpu is not None and (isinstance(old_module,
ColumnParallelLinear_Compress)
or isinstance(old_module,
mpu.ColumnParallelLinear)):
elif mpu is not None and (isinstance(old_module, ColumnParallelLinear_Compress)
or isinstance(old_module, mpu.ColumnParallelLinear)):
if isinstance(old_module, ColumnParallelLinear_Compress):
new_module = old_module
else:
new_module = ColumnParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
gather_output=old_module.gather_output,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = ColumnParallelLinear_Compress(mpu,
old_module.input_size,
old_module.output_size,
gather_output=old_module.gather_output,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif mpu is not None and (isinstance(old_module,
RowParallelLinear_Compress)
or isinstance(old_module,
mpu.RowParallelLinear)):
elif mpu is not None and (isinstance(old_module, RowParallelLinear_Compress)
or isinstance(old_module, mpu.RowParallelLinear)):
if isinstance(old_module, RowParallelLinear_Compress):
new_module = old_module
else:
new_module = RowParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
input_is_parallel=old_module.input_is_parallel,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = RowParallelLinear_Compress(mpu,
old_module.input_size,
old_module.output_size,
input_is_parallel=old_module.input_is_parallel,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
@@ -150,39 +135,30 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
for k, v in compression_technique.items():
if k == SPARSE_PRUNING:
if v[SPARSE_PRUNING_ENABLED]:
new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
v[SPARSE_PRUNING_METHOD])
new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], v[SPARSE_PRUNING_METHOD])
elif k == ROW_PRUNING:
if v[ROW_PRUNING_ENABLED]:
new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
v[ROW_PRUNING_METHOD])
new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], v[ROW_PRUNING_METHOD])
elif k == HEAD_PRUNING:
if v[HEAD_PRUNING_ENABLED]:
new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
v[HEAD_PRUNING_METHOD],
new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], v[HEAD_PRUNING_METHOD],
v[HEAD_PRUNING_NUM_HEADS])
elif k == ACTIVATION_QUANTIZATION:
if v[ACTIVATION_QUANTIZATION_ENABLED]:
new_module.enable_activation_quantization(
v[ACTIVATION_QUANTIZE_BITS],
v[ACTIVATION_QUANTIZE_TYPE],
v[ACTIVATION_QUANTIZE_RANGE])
new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS], v[ACTIVATION_QUANTIZE_TYPE],
v[ACTIVATION_QUANTIZE_RANGE])
elif k == WEIGHT_QUANTIZATION:
if v[WEIGHT_QUANTIZE_ENABLED]:
new_module.enable_weight_quantization(
v[WEIGHT_QUANTIZE_START_BITS],
v[WEIGHT_QUANTIZE_TARGET_BITS],
v[WEIGHT_QUANTIZATION_PERIOD],
v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
v[WEIGHT_QUANTIZE_TYPE],
v[WEIGHT_QUANTIZE_GROUPS])
new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
v[WEIGHT_QUANTIZE_TARGET_BITS],
v[WEIGHT_QUANTIZATION_PERIOD],
v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
v[WEIGHT_QUANTIZE_TYPE], v[WEIGHT_QUANTIZE_GROUPS])
elif k == CHANNEL_PRUNING:
if v[CHANNEL_PRUNING_ENABLED]:
new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
v[CHANNEL_PRUNING_METHOD])
new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], v[CHANNEL_PRUNING_METHOD])
else:
raise NotImplementedError(
'Compression technique {} is not implemented'.format(k))
raise NotImplementedError('Compression technique {} is not implemented'.format(k))
# Replace the old module with the new one
recursive_setattr(model, module_name, new_module)
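recursive_getattr and recursive_setattr walk a dotted module path so the freshly built compressed layer can be swapped in place. A minimal sketch of that pattern (illustrative only, not the exact DeepSpeed helpers):

import torch

def recursive_getattr_sketch(model: torch.nn.Module, module_name: str) -> torch.nn.Module:
    # Follow a dotted path such as "encoder.layer.0.attention" down to the leaf module.
    obj = model
    for attr in module_name.split("."):
        obj = getattr(obj, attr)
    return obj

def recursive_setattr_sketch(model: torch.nn.Module, module_name: str, new_module: torch.nn.Module) -> None:
    # Replace the leaf module while leaving the rest of the module tree untouched.
    parent = model
    attrs = module_name.split(".")
    for attr in attrs[:-1]:
        parent = getattr(parent, attr)
    setattr(parent, attrs[-1], new_module)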
@@ -195,10 +171,7 @@ def is_module_compressible(module, mpu=None):
isinstance(module, torch.nn.BatchNorm2d)
if mpu is not None:
ret = ret or isinstance(module,
mpu.RowParallelLinear) or isinstance(
module,
mpu.ColumnParallelLinear)
ret = ret or isinstance(module, mpu.RowParallelLinear) or isinstance(module, mpu.ColumnParallelLinear)
return ret
@@ -225,11 +198,7 @@ def compression_preparation(model, compression_techinique_list, mpu):
return model
def fix_compression(model,
module_name,
compression_technique,
mask=None,
dim_reduction=False):
def fix_compression(model, module_name, compression_technique, mask=None, dim_reduction=False):
"""
Fix the compression technique of a module.
Args:
@@ -243,17 +212,14 @@ def fix_compression(model,
# Here we can make things much simpler by just replacing the module
module = recursive_getattr(model, module_name)
for k, v in compression_technique.items():
if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[
WEIGHT_QUANTIZE_ENABLED]:
if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[WEIGHT_QUANTIZE_ENABLED]:
return module.fix_weight_quantization()
elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
return module.fix_sparse_pruning_helper()
elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
return module.fix_head_pruning_helper(mask,
v[HEAD_PRUNING_NUM_HEADS],
dim_reduction=dim_reduction)
return module.fix_head_pruning_helper(mask, v[HEAD_PRUNING_NUM_HEADS], dim_reduction=dim_reduction)
elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
@@ -270,10 +236,9 @@ def convert_conv1d_to_linear(model, convert_type):
for name, module in c_model.named_modules():
if isinstance(module, convert_type):
old_module = recursive_getattr(c_model, name)
new_module = torch.nn.Linear(
old_module.weight.data.size(0),
old_module.weight.data.size(1),
bias=True if old_module.bias is not None else False)
new_module = torch.nn.Linear(old_module.weight.data.size(0),
old_module.weight.data.size(1),
bias=True if old_module.bias is not None else False)
new_module.weight.data = old_module.weight.data.t().contiguous()
if new_module.bias is not None:
new_module.bias.data = old_module.bias.data.view(-1)
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .compress import get_module_name
from .constants import *
@@ -10,6 +13,7 @@ class compression_scheduler():
'''
Used to schedule different compression methods
'''
def __init__(self, model, compression_config):
self.model = model
self.compression_config = compression_config
@@ -38,22 +42,22 @@ class compression_scheduler():
}
exist_module_name = set()
shared_parameters = method_content[SHARED_PARAMETERS]
self.different_compression_methods[method][
TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
self.different_compression_methods[method][
SHARED_PARAMETERS] = shared_parameters
self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
module_name_list = []
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, self.model, key_word, exist_module_name, verbose=False)
module_name, exist_module_name = get_module_name(group_name,
self.model,
key_word,
exist_module_name,
verbose=False)
module_name_list.extend(module_name)
if module_name_list:
self.different_compression_methods[method][DIFFERENT_GROUPS].append([
group_name,
module_name_list,
method_parameters.copy().pop('params')
])
self.different_compression_methods[method][DIFFERENT_GROUPS].append(
[group_name, module_name_list,
method_parameters.copy().pop('params')])
def check_weight_quantization(self):
# check weight quantization
@@ -69,8 +73,7 @@ class compression_scheduler():
module.weight_quantization_enabled = True
if not self.verbose[WEIGHT_QUANTIZATION]:
logger.info(
f'Weight quantization is enabled at step {self.training_steps}')
logger.info(f'Weight quantization is enabled at step {self.training_steps}')
self.weight_quantization_enabled = True
self.verbose[WEIGHT_QUANTIZATION] = True
@@ -87,9 +90,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.activation_quantization_enabled = True
if not self.verbose[ACTIVATION_QUANTIZATION]:
logger.info(
f'Activation quantization is enabled at step {self.training_steps}'
)
logger.info(f'Activation quantization is enabled at step {self.training_steps}')
self.verbose[ACTIVATION_QUANTIZATION] = True
def check_sparse_pruning(self):
......@@ -105,8 +106,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.sparse_pruning_enabled = True
if not self.verbose[SPARSE_PRUNING]:
logger.info(
f'Sparse pruning is enabled at step {self.training_steps}')
logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
self.verbose[SPARSE_PRUNING] = True
def check_head_pruning(self):
@@ -154,8 +154,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.channel_pruning_enabled = True
if not self.verbose[CHANNEL_PRUNING]:
logger.info(
f'Channel pruning is enabled at step {self.training_steps}')
logger.info(f'Channel pruning is enabled at step {self.training_steps}')
self.verbose[CHANNEL_PRUNING] = True
def check_all_modules(self):
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from torch import autograd
@@ -13,6 +16,7 @@ class TopKBinarizer(autograd.Function):
Implementation is inspired from:
https://github.com/yaozhewei/MLPruning
"""
@staticmethod
def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
"""
@@ -59,6 +63,7 @@ class SymQuantizer(torch.autograd.Function):
"""
Symmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@@ -75,9 +80,8 @@ class SymQuantizer(torch.autograd.Function):
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
@@ -101,6 +105,7 @@ class AsymQuantizer(torch.autograd.Function):
"""
Asymmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@@ -118,9 +123,8 @@ class AsymQuantizer(torch.autograd.Function):
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
@@ -131,9 +135,7 @@ class AsymQuantizer(torch.autograd.Function):
scale = (max_value - min_value) / q_range
zero_point = (min_value / scale).round() * scale
output = (
(input - zero_point) / scale).round().clamp(0,
q_range - 1) * scale + zero_point
output = ((input - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
output = output.reshape(input_shape).contiguous()
return output
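The arithmetic above is standard asymmetric fake quantization; the same computation, pulled out of the autograd.Function machinery as a self-contained sketch:

import torch

def asym_fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # Map values in [min, max] onto 2**num_bits integer levels and back to floats.
    q_range = 2 ** num_bits
    min_value, max_value = x.min(), x.max()
    scale = (max_value - min_value) / q_range
    zero_point = (min_value / scale).round() * scale
    return ((x - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point

x = torch.randn(4, 8)
print((asym_fake_quantize(x) - x).abs().max())  # worst-case quantization error, on the order of one scale step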
@@ -147,6 +149,7 @@ class TernaryQuantizer(torch.autograd.Function):
"""
Ternary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@@ -187,6 +190,7 @@ class BinaryQuantizer(torch.autograd.Function):
"""
Binary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from datetime import timedelta
#############################################
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
from .utils import is_torch_elastic_compatible
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import json
from .constants import *
@@ -43,77 +44,64 @@ class ElasticityConfig:
"version": 0.1
}
"""
def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else:
raise ElasticityConfigError(
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
raise ElasticityConfigError(f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES]
else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else:
self.max_acceptable_batch_size = param_dict.get(
MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.max_acceptable_batch_size = param_dict.get(MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
if not isinstance(self.micro_batches, list):
raise ElasticityConfigError(
f"Elasticity expected value of {MICRO_BATCHES} to be a "
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}"
)
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}")
if not all(map(lambda m: isinstance(m, int), self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"instead contains: f{self.micro_batches}")
raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"instead contains: f{self.micro_batches}")
if not all(map(lambda m: m > 0, self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"instead contains: f{self.micro_batches}")
raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"instead contains: f{self.micro_batches}")
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
raise ElasticityConfigError("Elasticity min/max gpus must be > 0, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
if self.max_gpus < self.min_gpus:
raise ElasticityConfigError(
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
raise ElasticityConfigError("Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
MODEL_PARLLEL_SIZE_DEFAULT)
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1:
raise ElasticityConfigError(
"Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")
raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
NUM_GPUS_PER_NODE_DEFAULT)
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1:
raise ElasticityConfigError(
"Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")
raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
f"Elasticity min time needs to be >= 0: given {self.min_time}")
raise ElasticityConfigError(f"Elasticity min time needs to be >= 0: given {self.min_time}")
self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
def repr(self):
return self.__dict__
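Putting the validated fields together, an elasticity block that this class would accept might look like the following dict; the literal key strings mirror the constants above and are assumptions based on the public elasticity documentation:

elasticity_config = {
    "enabled": True,
    "max_train_batch_size": 2000,    # MAX_ACCEPTABLE_BATCH_SIZE
    "micro_batch_sizes": [2, 4, 6],  # MICRO_BATCHES
    "min_gpus": 1,
    "max_gpus": 10000,
    "min_time": 20,
    "version": 0.1,
    "prefer_larger_batch_size": True,
    "ignore_non_elastic_batch_info": False,
    "model_parallel_size": 1,        # used by elasticity v0.2
    "num_gpus_per_node": 1,          # used by elasticity v0.2
}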
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
#########################################
# Elasticity
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
@@ -23,6 +26,7 @@ import subprocess
class DSElasticAgent(LocalElasticAgent):
def __init__(
self,
spec: WorkerSpec,
@@ -35,9 +39,7 @@ class DSElasticAgent(LocalElasticAgent):
self.ds_env = env
@staticmethod
def _set_master_addr_port(store: Store,
master_addr: Optional[str],
master_port: Optional[int]):
def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
@@ -82,8 +84,7 @@ class DSElasticAgent(LocalElasticAgent):
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
str(1)),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
}
worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ:
@@ -120,8 +121,7 @@ class DSElasticAgent(LocalElasticAgent):
spec = self._worker_group.spec
role = spec.role
log.info(
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
@@ -136,13 +136,10 @@ class DSElasticAgent(LocalElasticAgent):
state = run_result.state
self._worker_group.state = state
expire_time = datetime.utcnow() - (
rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
_dead_nodes = [
node for node,
last_heartbeat in
rdzv_handler._state_holder.state.last_heartbeats.items()
node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
if last_heartbeat < expire_time
]
@@ -150,21 +147,16 @@ class DSElasticAgent(LocalElasticAgent):
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
log.info(f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
self._exit_barrier()
return run_result
elif state in {
WorkerState.UNHEALTHY,
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
log.info(f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group)
......
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import json
import numpy as np
@@ -17,44 +19,8 @@ from ..utils import logger
# Thirty eight smallest highly composite numbers. The list should
# be enough to support up to 720K batch size.
HCN_LIST = [
1,
2,
4,
6,
12,
24,
36,
48,
60,
120,
180,
240,
360,
720,
840,
1260,
1680,
2520,
5040,
7560,
10080,
15120,
20160,
25200,
27720,
45360,
50400,
55440,
83160,
110880,
166320,
221760,
277200,
332640,
498960,
554400,
665280,
720720
1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 7560, 10080, 15120, 20160,
25200, 27720, 45360, 50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, 720720
]
@@ -94,11 +60,7 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
return valid_gpus
def get_best_candidates(candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger):
def get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus, prefer_larger):
max_valid_gpus = 0
valid_gpus = None
@@ -106,15 +68,11 @@ def get_best_candidates(candidate_batch_sizes,
for batch_size in candidate_batch_sizes:
current_valid_gpus = get_valid_gpus(batch_size,
micro_batches,
min_gpus,
max_gpus)
current_valid_gpus = get_valid_gpus(batch_size, micro_batches, min_gpus, max_gpus)
if (len(current_valid_gpus) > max_valid_gpus
or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
if (len(current_valid_gpus) > max_valid_gpus or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
max_valid_gpus = len(current_valid_gpus)
valid_gpus = current_valid_gpus
final_batch_size = batch_size
@@ -157,15 +115,10 @@ def _get_compatible_gpus_v01(micro_batches,
base_list.extend(micro_batches)
base_list.append(lcm)
candidate_batch_sizes = get_candidate_batch_sizes(base_list,
max_acceptable_batch_size)
candidate_batch_sizes = get_candidate_batch_sizes(base_list, max_acceptable_batch_size)
final_batch_size, valid_gpus = get_best_candidates(
candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger)
final_batch_size, valid_gpus = get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus,
prefer_larger)
return final_batch_size, valid_gpus
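The selection logic above boils down to: a GPU count is valid for a candidate global batch size if the batch divides evenly into that many copies of one of the allowed micro-batch sizes. A small illustration of that check (a sketch of the idea, not the exact DeepSpeed routine):

def valid_gpu_counts(batch_size, micro_batches, min_gpus, max_gpus):
    # num_gpus works if some micro-batch size times num_gpus divides the global batch size.
    valid = set()
    for mb in micro_batches:
        for num_gpus in range(min_gpus, max_gpus + 1):
            if batch_size % (mb * num_gpus) == 0:
                valid.add(num_gpus)
    return sorted(valid)

print(valid_gpu_counts(720, micro_batches=[2, 3, 4], min_gpus=1, max_gpus=16))
# -> [1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16]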
@@ -203,11 +156,12 @@ def _get_compatible_gpus_v02(micro_batches,
dp_size_per_node = num_gpus_per_node // model_parallel_size
final_batch_size, valid_world_size = _get_compatible_gpus_v01(micro_batches,
int(max_acceptable_batch_size/dp_size_per_node),
int(min_gpus/num_gpus_per_node),
int(max_gpus/num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size, valid_world_size = _get_compatible_gpus_v01(
micro_batches,
int(max_acceptable_batch_size / dp_size_per_node),
int(min_gpus / num_gpus_per_node),
int(max_gpus / num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size = int(final_batch_size) * dp_size_per_node
valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
@@ -256,38 +210,27 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
Ensure the resource scheduler saw the same elastic config we are using at runtime
"""
if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
scheduler_elastic_config_dict = json.loads(
os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config_dict = json.loads(os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
raise ElasticityConfigError(
err_str.format('max_acceptable_batch_size',
scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size',
runtime_elastic_config.max_acceptable_batch_size))
err_str.format('max_acceptable_batch_size', scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size', runtime_elastic_config.max_acceptable_batch_size))
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
raise ElasticityConfigError(
err_str.format('micro_batches',
scheduler_elastic_config.micro_batches,
'micro_batches',
err_str.format('micro_batches', scheduler_elastic_config.micro_batches, 'micro_batches',
runtime_elastic_config.micro_batches))
if runtime_elastic_config.version != scheduler_elastic_config.version:
raise ElasticityConfigError(
err_str.format('version',
scheduler_elastic_config.version,
'version',
runtime_elastic_config.version))
err_str.format('version', scheduler_elastic_config.version, 'version', runtime_elastic_config.version))
else:
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
"guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict,
target_deepspeed_version: str,
world_size=0,
return_microbatch=False):
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0, return_microbatch=False):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below)
DeepSpeed will compute a total train batch size corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale
@@ -397,8 +340,7 @@ def compute_elastic_config(ds_config: dict,
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
else:
raise NotImplementedError(
f"Unable to find elastic logic for version: {elastic_config.version}")
raise NotImplementedError(f"Unable to find elastic logic for version: {elastic_config.version}")
logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import deepspeed
......@@ -48,8 +51,7 @@ def op_report(verbose=True):
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible(verbose) else no
is_installed = installed if installed_ops[op_name] else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) -
(len(is_installed) - color_len))
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
......@@ -68,9 +70,7 @@ def nvcc_version():
if cuda_home is None:
return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}"
try:
output = subprocess.check_output([cuda_home + "/bin/nvcc",
"-V"],
universal_newlines=True)
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
except FileNotFoundError:
return f"{RED} [FAIL] nvcc missing {END}"
output_split = output.split()
......@@ -82,32 +82,18 @@ def nvcc_version():
def debug_report():
max_dots = 33
report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)
]
report = [("torch install path", torch.__path__), ("torch version", torch.__version__),
("deepspeed install path", deepspeed.__path__),
("deepspeed info", f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}")]
if get_accelerator().device_name() == 'cuda':
hip_version = getattr(torch.version, "hip", None)
report.extend([("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}"))])
report.extend([("torch cuda version", torch.version.cuda), ("torch hip version", hip_version),
("nvcc version", (None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
])
else:
report.extend([("deepspeed wheel compiled w.",
f"torch {torch_info['version']} ")])
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
print("DeepSpeed general environment info:")
for name, value in report:
......@@ -116,15 +102,10 @@ def debug_report():
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--hide_operator_status',
action='store_true',
help=
'Suppress display of installation and compatibility statuses of DeepSpeed operators. '
)
parser.add_argument('--hide_errors_and_warnings',
parser.add_argument('--hide_operator_status',
action='store_true',
help='Suppress warning and error messages.')
help='Suppress display of installation and compatibility statuses of DeepSpeed operators. ')
parser.add_argument('--hide_errors_and_warnings', action='store_true', help='Suppress warning and error messages.')
args = parser.parse_args()
return args
......@@ -137,8 +118,7 @@ def main(hide_operator_status=False, hide_errors_and_warnings=False):
def cli_main():
args = parse_arguments()
main(hide_operator_status=args.hide_operator_status,
hide_errors_and_warnings=args.hide_errors_and_warnings)
main(hide_operator_status=args.hide_operator_status, hide_errors_and_warnings=args.hide_errors_and_warnings)
if __name__ == "__main__":
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
try:
# This is populated by setup.py
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .engine import InferenceEngine
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
......@@ -12,10 +15,12 @@ from enum import Enum
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8"
# bf16 not supported
# bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
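# A minimal sketch of the multi-value Enum pattern referenced above, following the
# linked StackOverflow answer rather than DeepSpeed's exact implementation: the first
# value becomes the canonical member value and the remaining values are registered as
# aliases, so the Enum can be constructed from any of them.
from enum import Enum

class MultiValueEnum(Enum):
    def __new__(cls, *values):
        obj = object.__new__(cls)
        obj._value_ = values[0]                 # canonical value (e.g. the torch dtype)
        for alias in values[1:]:                # extra lookup keys (e.g. "fp16", "half")
            cls._value2member_map_[alias] = obj
        return obj

class Size(MultiValueEnum):
    small = 1, "s", "small"
    large = 2, "l", "large"

assert Size("s") is Size.small and Size.small.value == 1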
......@@ -192,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
This can be passed through the json config too.
"""
set_empty_params: bool = False
"""
Specifies whether the inference module is created with empty or real tensors.
"""
save_mp_checkpoint_path: str = None
"""
The path for which we want to save the loaded model with a checkpoint. This
......@@ -222,9 +232,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
replace_method: str = Field(
"auto",
deprecated=True,
deprecated_msg=
"This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
)
deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
injection_policy: Dict = Field(None, alias="injection_dict")
"""
......@@ -235,9 +243,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
injection_policy_tuple: tuple = None
""" TODO: Add docs """
config: Dict = Field(
None,
alias="args") # todo: really no need for this field if we can refactor
config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens")
"""
......@@ -246,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
to the token length required for your use-case.
"""
min_out_tokens: int = Field(1, alias="min_tokens")
"""
This argument communicates to the runtime the minimum number of tokens you
expect you will need to generate. This will cause the runtime to error
if it is unable to provide this, and to report context on the memory pressure,
rather than seg-faulting or producing corrupted output.
"""
transposed_mode: bool = Field(False, alias="transposed_mode")
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
"""
Desired model parallel size, default is 1 meaning no model parallelism.
......@@ -254,18 +270,10 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
"""
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
ep_group: object = Field(None,
alias="expert_group",
deprecated=True,
new_param="moe.ep_group")
ep_mp_group: object = Field(None,
alias="expert_mp_group",
deprecated=True,
new_param="moe.ep_mp_group")
ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
deprecated=True,
new_param="moe.type")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
@validator("moe")
def moe_backward_compat(cls, field_value, values):
......
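# A minimal usage sketch for the inference config fields above (model name, dtype and
# token limits are placeholder values; max_tokens/min_tokens are the aliases of
# max_out_tokens/min_out_tokens, and tensor_parallel.tp_size replaces the deprecated
# mp_size):
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 1},
    dtype=torch.half,
    replace_with_kernel_inject=True,
    max_tokens=1024,
    min_tokens=1,
)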
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import time
import os
......@@ -32,6 +34,58 @@ from torch import nn
INFERENCE_MODEL_TIMER = "model-forward-inference"
def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor is not causal as the original paper mentions; it
relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
    attention_mask (`torch.Tensor`):
        Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
    num_heads (`int`, *required*):
        number of heads
    dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
        dtype of the output tensor
Returns:
    Tensor shaped (batch_size * num_heads, 1, max_seq_len).
"""
import math
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
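# A worked example of the slope schedule computed above, for num_heads=8 (a power of
# two, so the extra_base branch is skipped): base = 2**-(2**-(log2(8)-3)) = 0.5 and the
# per-head slopes are base**[1..8] = [1/2, 1/4, ..., 1/256], matching the ALiBi paper.
import math

num_heads = 8
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = 2**(-(2**-(math.log2(closest_power_of_2) - 3)))
slopes = [base**p for p in range(1, closest_power_of_2 + 1)]
assert slopes[0] == 0.5 and slopes[-1] == 1 / 256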
class InferenceEngine(Module):
inference_mp_group = None
inference_ep_group = None
......@@ -80,13 +134,18 @@ class InferenceEngine(Module):
self.checkpoint_engine = TorchCheckpointEngine()
quantization_setting = None
self._init_quantization_setting(
quantization_setting
) # todo: update with the new quant config for weight quant
quantization_setting) # todo: update with the new quant config for weight quant
self.model_profile_enabled = False
self._model_times = []
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
self.remove_mask_prepare_for_bloom()
if not self.injection_dict and config.replace_with_kernel_inject:
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
self.remove_mask_prepare_for_bloom()
if self.injection_dict or not config.replace_with_kernel_inject:
# This is a hack to redefine the alibi func due to TP
if config.tensor_parallel.tp_size > 1:
self.build_alibi_tensor()
if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
......@@ -100,8 +159,7 @@ class InferenceEngine(Module):
self._convert_to_dtype(config)
if self.mpu:
config.tensor_parallel.tp_size = dist.get_world_size(
group=self.mpu.get_model_parallel_group())
config.tensor_parallel.tp_size = dist.get_world_size(group=self.mpu.get_model_parallel_group())
self.mp_group = self.mpu.get_model_parallel_group()
elif config.tensor_parallel.tp_size > 1:
self._create_model_parallel_group(config)
......@@ -149,8 +207,7 @@ class InferenceEngine(Module):
self.module.to(device)
if config.tensor_parallel.tp_size > 1:
_rng_state = get_accelerator().get_rng_state().to(
get_accelerator().current_device_name())
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu())
......@@ -172,15 +229,18 @@ class InferenceEngine(Module):
# todo: remove this once all the config dicts are centralized from top level pydantic config
def _get_model_config_generate(self, config):
# this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
self.config = getattr(self.module,
'config',
None) if config.config is None else config.config
self.config = getattr(self.module, 'config', None) if config.config is None else config.config
def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, '_prepare_attn_mask'):
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
def build_alibi_tensor(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, 'build_alibi_tensor'):
self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
def _pre_forward_hook(self, module, *inputs, **kwargs):
if self.use_cuda_events:
self.timers(INFERENCE_MODEL_TIMER).start()
......@@ -223,8 +283,7 @@ class InferenceEngine(Module):
num_ep_groups = dist.get_world_size() // moe_ep_size
for i in range(num_ep_groups):
ep_cnt = i * moe_ep_size
size = dist.get_world_size(
) if moe_ep_size > dist.get_world_size() else moe_ep_size
size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
ranks = list(range(ep_cnt, ep_cnt + size))
_ep_group = dist.new_group(ranks)
if dist.get_rank() in ranks:
......@@ -234,9 +293,7 @@ class InferenceEngine(Module):
num_expert_mp_groups = dist.get_world_size() // num_ep_groups
expert_mp_size = dist.get_world_size() // moe_ep_size
for i in range(num_expert_mp_groups):
expert_mp_comm_ranks = [
i + nr * moe_ep_size for nr in range(expert_mp_size)
]
expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
_expert_mp_group = dist.new_group(expert_mp_comm_ranks)
if dist.get_rank() in expert_mp_comm_ranks:
self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
......@@ -253,65 +310,48 @@ class InferenceEngine(Module):
log_dist(
f"quantize_bits = {self.quantize_bits} "
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}",
[0])
f"quantize_groups = {self.quantize_groups}", [0])
# TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support SD pipeline we need to avoid this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self._config.tensor_parallel.tp_size,
int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(
f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}"
)
if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self._config.checkpoint is not None and not isinstance(
self._config.checkpoint,
(str,
dict)):
raise ValueError(
f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}"
)
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
supported_dtypes = [None, torch.half, torch.int8, torch.float]
if self._config.dtype not in supported_dtypes:
raise ValueError(
f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError(
f"injection_dict must be None or a dict, got: {self.injection_dict}")
raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group,
mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
error_msgs = []
def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'):
if 'query_key_value' in prefix:
module.weight = self.mp_replace.qkv_copy(
module.weight.data,
state_dict[prefix + 'weight'])
module.weight = self.mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight'])
else:
module.weight = self.mp_replace.copy(module.weight.data,
state_dict[prefix + 'weight'])
module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
else:
module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
state_dict[prefix + 'weight'])
module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'):
module.norm.bias = self.mp_replace.copy(module.norm.bias,
state_dict[prefix + 'bias'])
module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
else:
data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name())
......@@ -331,45 +371,32 @@ class InferenceEngine(Module):
checking_key = prefix + name + '.'
if not any(checking_key in item for item in self.key_list):
continue
if len(list(child.parameters())) > 0 and list(
child.parameters())[0].numel() == 0:
if len(list(child.parameters())) > 0 and list(child.parameters())[0].numel() == 0:
if len(child.weight.ds_shape) == 1:
child = Normalize(dim=child.weight.ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
setattr(module, name, child)
load(child, self.sd, prefix + name + '.')
else:
load_module_recursive(child,
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)
load_module_recursive(r_module)
def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
generic_injection(self.module,
fp16=(config.dtype == torch.half)
or (config.dtype == torch.int8),
fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
enable_cuda_graph=config.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module):
# config is our DeepSpeedInferenceConfig and self.config is the HF model config
replace_transformer_layer(client_module,
self.module,
checkpoint,
config,
self.config)
replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
tag,
mp_placeholder="*")
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
import glob
ckpt_files = glob.glob(ckpt_file_pattern)
......@@ -392,8 +419,7 @@ class InferenceEngine(Module):
def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
is_pipe_parallel = isinstance(self.module, PipelineModule)
if is_pipe_parallel:
raise RuntimeError(
'pipeline parallelism is currently not supported in inference.')
raise RuntimeError('pipeline parallelism is currently not supported in inference.')
if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
if tag is None:
latest_path = os.path.join(load_dir, "latest")
......@@ -404,8 +430,7 @@ class InferenceEngine(Module):
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
......@@ -416,19 +441,18 @@ class InferenceEngine(Module):
for i in range(1, len(sd_loader)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i],
map_location=get_accelerator().device_name())
self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
self.quantization_scales, self.quantize_merge_count = quantize_config
......@@ -438,21 +462,20 @@ class InferenceEngine(Module):
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
DeepSpeedEngine.load_moe_state_dict(
load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
checkpoint_engine=self.checkpoint_engine)
self.module.load_state_dict(
state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
DeepSpeedEngine.load_moe_state_dict(load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
checkpoint_engine=self.checkpoint_engine)
self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
def _choose_module_key(self, sd):
assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert not ('module' in sd
and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
if 'module' in sd:
return 'module'
......@@ -465,10 +488,8 @@ class InferenceEngine(Module):
if False: #config.dtype is torch.int8 and self.quantization_scales is None:
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
model, self.quantization_scales = quantizer.model_quantize(self.module,
self.injection_dict,
self.quantize_bits,
self.quantize_groups)
model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict,
self.quantize_bits, self.quantize_groups)
elif config.dtype == torch.half:
self.module.half()
elif config.dtype == torch.bfloat16:
......@@ -509,11 +530,10 @@ class InferenceEngine(Module):
assert self.model_profile_enabled, "model profiling is not enabled"
model_times = self._model_times
if self._config.enable_cuda_graph and len(self._model_times) == 0:
raise ValueError(
"Model times are empty and cuda graph is enabled. If "
"this is a GPT-style model this combo is not supported. If this is a "
"BERT-style model this is a bug, please report it. "
f"Model type is: {type(self.module)}")
raise ValueError("Model times are empty and cuda graph is enabled. If "
"this is a GPT-style model this combo is not supported. If this is a "
"BERT-style model this is a bug, please report it. "
f"Model type is: {type(self.module)}")
self._model_times = []
return model_times
......@@ -532,8 +552,7 @@ class InferenceEngine(Module):
for name in module.__dict__.keys():
sub_module = getattr(module, name)
if self._module_match(sub_module) and hasattr(sub_module,
"enable_cuda_graph"):
if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"):
sub_module_cuda_graph = True
return sub_module_cuda_graph
......@@ -546,13 +565,11 @@ class InferenceEngine(Module):
**kwargs: variable length keyword arguments
"""
start = None
if self.model_profile_enabled and get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph:
if self.model_profile_enabled and get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
get_accelerator().synchronize()
start = time.time()
if get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
......@@ -580,9 +597,7 @@ class InferenceEngine(Module):
num_beams = kwargs["num_beams"]
if num_beams > 1:
raise NotImplementedError(
"DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
)
raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506")
return self.module.generate(*inputs, **kwargs)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright 2020 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024
......
# Copyright 2020 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed launcher; this is similar to torch's distributed.launch but supports
additional features such as arbitrary gpu exclusion.
......@@ -16,6 +19,7 @@ import base64
import time
import signal
import psutil
import distutils
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
......@@ -52,10 +56,7 @@ def parse_args():
help="Master node (rank 0)'s free port that needs to "
"be used for communication during distributed "
"training")
parser.add_argument("--world_info",
default="None",
type=str,
help="world info base64 encoded dictionary")
parser.add_argument("--world_info", default="None", type=str, help="world info base64 encoded dictionary")
parser.add_argument("--module",
action="store_true",
......@@ -68,19 +69,11 @@ def parse_args():
help="Skip prepending the training script with "
"'python' - just execute it directly.")
parser.add_argument("--enable_elastic_training",
action="store_true",
help="Enable elastic training support.")
parser.add_argument("--enable_elastic_training", action="store_true", help="Enable elastic training support.")
parser.add_argument("--min_elastic_nodes",
type=int,
default=-1,
help="Min number of nodes in elastic training.")
parser.add_argument("--min_elastic_nodes", type=int, default=-1, help="Min number of nodes in elastic training.")
parser.add_argument("--max_elastic_nodes",
type=int,
default=-1,
help="Max number of nodes in elastic training.")
parser.add_argument("--max_elastic_nodes", type=int, default=-1, help="Max number of nodes in elastic training.")
parser.add_argument("--no_local_rank",
action="store_true",
......@@ -92,11 +85,22 @@ def parse_args():
default=0,
help="main launching process pid, for internal pid tracking")
parser.add_argument(
"--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument("--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument("--bind_cores_to_rank",
action="store_true",
help="Bind each rank to different cores of the host. "
"This improves host efficiency especially for CPU backend")
parser.add_argument("--bind_core_list",
type=str,
default=None,
help="List of cores to bind to with comma separated list of "
"numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
"specified, all cores on system would be used rank binding")
# positional
parser.add_argument("training_script",
......@@ -126,6 +130,89 @@ def terminate_process_tree(pid):
p.kill()
def parse_range(rng):
try:
value = int(rng)
return range(value, value + 1)
except ValueError:
# value is not a single number
parts = rng.split('-')
if len(parts) != 2:
raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
(rng, ))
start = int(parts[0])
end = int(parts[1])
if start > end:
raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
return range(start, end + 1)
# parse comma and dash separated range list into list
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Range list numbers must be comma separated; each item is either a single number,
#    or a range marked by two numbers (both numbers are included in the range)
# 2. Sub-ranges must be in ascending order and must not overlap with each other
# 3. No spaces in the range expression
def parse_range_list(range_str):
number_list = []
last = -1
range_list = range_str.split(',')
for sub_range in range_list:
sub_number_list = parse_range(sub_range)
if sub_number_list[0] <= last:
raise ValueError(
"Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
(range_str, ))
last = sub_number_list[-1]
number_list.extend(sub_number_list)
return number_list
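# Quick illustration of the helpers above (these inputs match the example in the
# comment above parse_range_list):
assert list(parse_range("3")) == [3]
assert list(parse_range("2-4")) == [2, 3, 4]
assert parse_range_list("0,2-4,6") == [0, 2, 3, 4, 6]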
# return a list of lists for the core-to-NUMA mapping
# [
#   [ cores for numa 0 ]
#   [ cores belonging to numa 1 ]
# ...
# ]
def get_numa_cores():
ret = []
output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
lines = output.split('\n')
for line in lines:
if line.startswith('available:'):
num_numas = int(line.split(' ')[1])
break
for numa in range(num_numas):
for line in lines:
if line.startswith(f'node {numa} cpus:'):
cores = line.split(' ')[3:]
ret.append([int(core) for core in cores])
return ret
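# On a hypothetical two-socket host, `numactl --hardware` prints lines such as
#   available: 2 nodes (0-1)
#   node 0 cpus: 0 1 2 3
#   node 1 cpus: 4 5 6 7
# which the parser above turns into [[0, 1, 2, 3], [4, 5, 6, 7]].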
def check_for_numactl_pkg():
libs = dict(
dpkg=["-l", "numactl", "apt"],
pacman=["-Q", "numactl", "pacman"],
rpm=["-q", "numactl", "yum"],
)
found = False
for pkgmgr, data in libs.items():
flag, lib, tool = data
path = distutils.spawn.find_executable(pkgmgr)
if path is not None:
cmd = f"{pkgmgr} {flag} {lib}"
result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
if result.wait() == 0:
found = True
else:
print(f"please install the {lib} package with {tool}")
break
return found
def main():
args = parse_args()
current_env = os.environ.copy()
......@@ -145,9 +232,7 @@ def main():
local_node = node_list[args.node_rank]
local_gpu_ids = world_info[local_node]
num_local_procs = len(local_gpu_ids)
logger.info(
f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
)
logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}")
global_rank_mapping = defaultdict(list)
curr_global_rank = 0
......@@ -193,8 +278,7 @@ def main():
lines = file.readlines()
lines = [line.rstrip() for line in lines]
for line in lines:
if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
'export FC_TASK_INDEX'):
if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
key_val = line.split()[1]
key, val = key_val.split('=')
current_env[key] = val
......@@ -206,17 +290,13 @@ def main():
if args.enable_each_rank_log != "None":
# prepare the log path and the file name prefix
if os.path.isfile(args.enable_each_rank_log):
raise ValueError(
f"{args.enable_each_rank_log} should not be a file, it should be a directory."
)
raise ValueError(f"{args.enable_each_rank_log} should not be a file, it should be a directory.")
if not os.path.exists(args.enable_each_rank_log):
try:
os.makedirs(args.enable_each_rank_log)
except Exception as e:
print(e)
raise ValueError(
f"unable to create directory {args.enable_each_rank_log} for each rank log."
)
raise ValueError(f"unable to create directory {args.enable_each_rank_log} for each rank log.")
log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())
for local_rank in range(0, num_local_procs):
......@@ -227,8 +307,43 @@ def main():
# spawn the processes
cmd = []
if args.bind_cores_to_rank:
check_for_numactl_pkg()
if 'KMP_AFFINITY' in os.environ.keys():
raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
"because it interfere with how many CPU cores numactl can set. "
"Unset KMP_AFFINITY before launching deepspeed.\n\n"
"\t$ unset KMP_AFFINITY\n"
"\t$ deepspeed <deepspeed command parameters>")
if args.bind_core_list != None:
core_list = parse_range_list(args.bind_core_list)
total_cores = len(core_list)
else:
total_cores = psutil.cpu_count(logical=False)
core_list = range(total_cores)
cores_per_rank = total_cores // num_local_procs
assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}"
cmd.append("numactl")
# check if all cores belong to the same numa; if true, bind the process to that numa domain with the -m parameter
numa_cores = get_numa_cores()
num_numas = len(numa_cores)
for i in range(num_numas):
if set(core_list_for_rank) <= set(numa_cores[i]):
cmd.append("-m")
cmd.append(f"{i}")
break
cmd.append("-C")
core_list_str = f"{core_list_for_rank[0]}"
for core_id in core_list_for_rank[1:]:
core_list_str = f"{core_list_str},{core_id}"
cmd.append(f"{core_list_str}")
if not args.no_python:
cmd = [sys.executable, "-u"]
cmd.append(sys.executable)
cmd.append("-u")
if args.module:
cmd.append("-m")
else:
......@@ -242,13 +357,9 @@ def main():
cmd += args.training_script_args
if args.enable_each_rank_log != "None":
log_file = os.path.join(args.enable_each_rank_log,
f"{log_name_prefix}_rank{dist_rank}.log")
log_file = os.path.join(args.enable_each_rank_log, f"{log_name_prefix}_rank{dist_rank}.log")
log_fd = open(log_file, 'w')
process = subprocess.Popen(cmd,
env=current_env,
stdout=log_fd,
stderr=log_fd)
process = subprocess.Popen(cmd, env=current_env, stdout=log_fd, stderr=log_fd)
else:
process = subprocess.Popen(cmd, env=current_env)
......@@ -264,7 +375,7 @@ def main():
args.min_elastic_nodes = 1
if args.max_elastic_nodes == -1:
args.max_elastic_nodes = args.nnodes
assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0 , "Max and Min nodes should be positive"
assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"
current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
......@@ -287,8 +398,7 @@ def main():
# Creating config for rendezvous class
rdzv_parameters = RendezvousParameters(backend='c10d',
endpoint=args.master_addr + ":" +
str(args.master_port),
endpoint=args.master_addr + ":" + str(args.master_port),
run_id=run_id,
min_nodes=args.min_elastic_nodes,
max_nodes=args.max_elastic_nodes,
......