Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .constants import *
import copy
@@ -36,9 +39,7 @@ def get_layer_reduction(param_dict):
def get_layer_reduction_enabled(param_dict):
    if LAYER_REDUCTION in param_dict.keys():
        return get_scalar_param(param_dict[LAYER_REDUCTION], LAYER_REDUCTION_ENABLED, LAYER_REDUCTION_ENABLED_DEFAULT)
    else:
        return False
@@ -70,7 +71,8 @@ def get_weight_quantization(param_dict):
    output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Weight Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
    return output
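For orientation, get_weight_quantization walks a param_dict carrying a nested section shaped roughly like the sketch below. The literal key strings are assumptions for illustration only; the real names are the constants imported from .constants, so treat this as structure, not a verbatim schema:

param_dict = {
    "weight_quantization": {
        "shared_parameters": {
            "enabled": True,
            "quantize_groups": 1,
            "quantization_type": "symmetric",
            "rounding": "nearest",
            "schedule_offset": 0
        },
        "different_groups": {
            "wq1": {
                "params": {"start_bits": 8, "target_bits": 4, "quantization_period": 50},
                "modules": ["attention.self"]
            }
        }
    }
}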
@@ -79,51 +81,38 @@ def get_weight_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ENABLED,
                                                           WEIGHT_QUANTIZE_ENABLED_DEFAULT)
        output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_KERNEL,
                                                          WEIGHT_QUANTIZE_KERNEL_DEFAULT)
        output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
                                                                   WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
        output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_GROUPS,
                                                          WEIGHT_QUANTIZE_GROUPS_DEFAULT)
        output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_VERBOSE,
                                                           WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
        output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_TYPE,
                                                        WEIGHT_QUANTIZE_TYPE_DEFAULT)
        output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(sub_param_dict,
                                                                      WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
                                                                      WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
        assert output[WEIGHT_QUANTIZE_TYPE] in [
            WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC
        ], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
        output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ROUNDING,
                                                            WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
        assert output[WEIGHT_QUANTIZE_ROUNDING] in [
            WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING
        ], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
        if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
                sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
                WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
                sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_CHANGE_RATIO,
                WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
        else:
            output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
            output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    else:
        output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
@@ -133,8 +122,7 @@ def get_weight_quantization_shared_parameters(param_dict):
        output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
        output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
        output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
        output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
        output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
    return output
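get_scalar_param itself is imported from elsewhere in DeepSpeed and is not part of this diff; for reading the parsing code above it behaves like a dictionary lookup with a default, roughly as in this sketch (not the upstream implementation):

def get_scalar_param(param_dict, param_name, param_default):
    # Return the configured value when the key is present, otherwise the default.
    return param_dict.get(param_name, param_default)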
@@ -144,27 +132,21 @@ def get_weight_quantization_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(
        ), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
        assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(
        ), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
        group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(group_dict, WEIGHT_QUANTIZATION_PERIOD,
                                                                  WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
@@ -172,19 +154,15 @@ def get_weight_quantization_different_groups(param_dict):
def get_activation_quantization(param_dict):
    output = {}
    if ACTIVATION_QUANTIZATION not in param_dict.keys():
        param_dict[ACTIVATION_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
    sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
    # shared parameters
    output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(sub_param_dict)
    return output
@@ -192,30 +170,26 @@ def get_activation_quantization_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZATION_ENABLED,
                                                                   ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
        output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_TYPE,
                                                            ACTIVATION_QUANTIZE_TYPE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_TYPE] in [
            ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC
        ], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
        output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_RANGE,
                                                             ACTIVATION_QUANTIZE_RANGE_DEFAULT)
        assert output[ACTIVATION_QUANTIZE_RANGE] in [
            ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC
        ], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict,
                                                                       ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
                                                                       ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
        output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
        output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
        output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
    return output
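The activation-quantization block mirrors the weight-quantization layout, adding the dynamic/static range-calibration choice validated above. An illustrative sketch, with key strings again assumed rather than taken from .constants:

param_dict["activation_quantization"] = {
    "shared_parameters": {
        "enabled": True,
        "quantization_type": "asymmetric",
        "range_calibration": "dynamic",
        "schedule_offset": 10
    },
    "different_groups": {
        "aq1": {
            "params": {"bits": 8},
            "modules": ["attention.output"]
        }
    }
}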
@@ -224,22 +198,17 @@ def get_activation_quantization_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(
        ), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
@@ -253,7 +222,8 @@ def get_sparse_pruning(param_dict):
    output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
    return output
@@ -262,18 +232,15 @@ def get_sparse_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
                                                          SPARSE_PRUNING_ENABLED_DEFAULT)
        output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
                                                         SPARSE_PRUNING_METHOD_DEFAULT)
        assert output[SPARSE_PRUNING_METHOD] in [
            SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
        ], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
        output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
                                                                  SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
        output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
@@ -286,22 +253,17 @@ def get_sparse_pruning_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(
        ), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
@@ -315,7 +277,8 @@ def get_row_pruning(param_dict):
    output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
    return output
@@ -324,17 +287,14 @@ def get_row_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, ROW_PRUNING_ENABLED,
                                                       ROW_PRUNING_ENABLED_DEFAULT)
        output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, ROW_PRUNING_METHOD, ROW_PRUNING_METHOD_DEFAULT)
        assert output[ROW_PRUNING_METHOD] in [
            ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK
        ], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
        output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ROW_PRUNING_SCHEDULE_OFFSET,
                                                               ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
        output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
@@ -347,22 +307,17 @@ def get_row_pruning_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(
        ), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
@@ -375,7 +330,8 @@ def get_head_pruning(param_dict):
    output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
    return output
@@ -384,19 +340,18 @@ def get_head_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, HEAD_PRUNING_ENABLED,
                                                        HEAD_PRUNING_ENABLED_DEFAULT)
        output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, HEAD_PRUNING_METHOD,
                                                       HEAD_PRUNING_METHOD_DEFAULT)
        assert output[HEAD_PRUNING_METHOD] in [
            HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK
        ], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
        output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, HEAD_PRUNING_SCHEDULE_OFFSET,
                                                                HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
        if output[HEAD_PRUNING_ENABLED]:
            assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(
            ), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
            output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
    else:
        output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
@@ -410,22 +365,17 @@ def get_head_pruning_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(
        ), f"dense_ratio must be specified for head pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
@@ -438,7 +388,8 @@ def get_channel_pruning(param_dict):
    output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
    # each sub-groups
    if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
        assert DIFFERENT_GROUPS in sub_param_dict.keys(
        ), f"Channel Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
        output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
    return output
@@ -447,19 +398,15 @@ def get_channel_pruning_shared_parameters(param_dict):
    output = {}
    if SHARED_PARAMETERS in param_dict.keys():
        sub_param_dict = param_dict[SHARED_PARAMETERS]
        output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_ENABLED,
                                                           CHANNEL_PRUNING_ENABLED_DEFAULT)
        output[CHANNEL_PRUNING_METHOD] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_METHOD,
                                                          CHANNEL_PRUNING_METHOD_DEFAULT)
        assert output[CHANNEL_PRUNING_METHOD] in [
            CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK
        ], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
        output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_SCHEDULE_OFFSET,
                                                                   CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
    else:
        output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
        output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
@@ -472,21 +419,16 @@ def get_channel_pruning_different_groups(param_dict):
    sub_param_dict = param_dict[DIFFERENT_GROUPS]

    def get_params(name, group_dict):
        assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(
        ), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
        return group_dict

    for k, v in sub_param_dict.items():
        output[k] = {}
        output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
        output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
                                                                    DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
        output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
            sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)

    return output
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#########################################
# Compression Methods
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
@@ -59,23 +62,17 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
    need_bias = True
    # Initialize the new module
    if isinstance(old_module, LinearLayer_Compress) or isinstance(old_module, torch.nn.Linear):
        if isinstance(old_module, LinearLayer_Compress):
            new_module = old_module
        else:
            new_module = LinearLayer_Compress(old_module.in_features, old_module.out_features,
                                              bias=need_bias).to(device=old_module.weight.device,
                                                                 dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, Conv2dLayer_Compress) or isinstance(old_module, torch.nn.Conv2d):
        if isinstance(old_module, Conv2dLayer_Compress):
            new_module = old_module
        else:
@@ -86,60 +83,48 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif isinstance(old_module, torch.nn.BatchNorm2d):
        new_module = BNLayer_Compress(old_module.num_features, old_module.eps, old_module.momentum, old_module.affine,
                                      old_module.track_running_stats).to(old_module.weight.device,
                                                                         old_module.weight.dtype)
        new_module.weight.data = old_module.weight.data
        if need_bias:
            new_module.bias.data = old_module.bias.data
        new_module.running_mean.data = old_module.running_mean.data
        new_module.running_var.data = old_module.running_var.data
    elif isinstance(old_module, Embedding_Compress) or isinstance(old_module, torch.nn.Embedding):
        if isinstance(old_module, Embedding_Compress):
            new_module = old_module
        else:
            new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
                            old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
    elif mpu is not None and (isinstance(old_module, ColumnParallelLinear_Compress)
                              or isinstance(old_module, mpu.ColumnParallelLinear)):
        if isinstance(old_module, ColumnParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = ColumnParallelLinear_Compress(mpu,
                                                       old_module.input_size,
                                                       old_module.output_size,
                                                       gather_output=old_module.gather_output,
                                                       skip_bias_add=old_module.skip_bias_add,
                                                       bias=need_bias).to(device=old_module.weight.device,
                                                                          dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
    elif mpu is not None and (isinstance(old_module, RowParallelLinear_Compress)
                              or isinstance(old_module, mpu.RowParallelLinear)):
        if isinstance(old_module, RowParallelLinear_Compress):
            new_module = old_module
        else:
            new_module = RowParallelLinear_Compress(mpu,
                                                    old_module.input_size,
                                                    old_module.output_size,
                                                    input_is_parallel=old_module.input_is_parallel,
                                                    skip_bias_add=old_module.skip_bias_add,
                                                    bias=need_bias).to(device=old_module.weight.device,
                                                                       dtype=old_module.weight.dtype)
            new_module.weight.data = old_module.weight.data
            if need_bias:
                new_module.bias.data = old_module.bias.data
@@ -150,39 +135,30 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
        for k, v in compression_technique.items():
            if k == SPARSE_PRUNING:
                if v[SPARSE_PRUNING_ENABLED]:
                    new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], v[SPARSE_PRUNING_METHOD])
            elif k == ROW_PRUNING:
                if v[ROW_PRUNING_ENABLED]:
                    new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], v[ROW_PRUNING_METHOD])
            elif k == HEAD_PRUNING:
                if v[HEAD_PRUNING_ENABLED]:
                    new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], v[HEAD_PRUNING_METHOD],
                                                   v[HEAD_PRUNING_NUM_HEADS])
            elif k == ACTIVATION_QUANTIZATION:
                if v[ACTIVATION_QUANTIZATION_ENABLED]:
                    new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS], v[ACTIVATION_QUANTIZE_TYPE],
                                                              v[ACTIVATION_QUANTIZE_RANGE])
            elif k == WEIGHT_QUANTIZATION:
                if v[WEIGHT_QUANTIZE_ENABLED]:
                    new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
                                                          v[WEIGHT_QUANTIZE_TARGET_BITS],
                                                          v[WEIGHT_QUANTIZATION_PERIOD],
                                                          v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
                                                          v[WEIGHT_QUANTIZE_TYPE], v[WEIGHT_QUANTIZE_GROUPS])
            elif k == CHANNEL_PRUNING:
                if v[CHANNEL_PRUNING_ENABLED]:
                    new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], v[CHANNEL_PRUNING_METHOD])
            else:
                raise NotImplementedError('Compression technique {} is not implemented'.format(k))

    # Replace the old module with the new one
    recursive_setattr(model, module_name, new_module)
@@ -195,10 +171,7 @@ def is_module_compressible(module, mpu=None):
          isinstance(module, torch.nn.BatchNorm2d)

    if mpu is not None:
        ret = ret or isinstance(module, mpu.RowParallelLinear) or isinstance(module, mpu.ColumnParallelLinear)

    return ret
@@ -225,11 +198,7 @@ def compression_preparation(model, compression_techinique_list, mpu):
    return model


def fix_compression(model, module_name, compression_technique, mask=None, dim_reduction=False):
    """
    Fix the compression technique of a module.
    Args:
@@ -243,17 +212,14 @@ def fix_compression(model,
    # Here we can make things much simpler by just replacing the module
    module = recursive_getattr(model, module_name)
    for k, v in compression_technique.items():
        if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[WEIGHT_QUANTIZE_ENABLED]:
            return module.fix_weight_quantization()
        elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
            return module.fix_sparse_pruning_helper()
        elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
            return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
        elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
            return module.fix_head_pruning_helper(mask, v[HEAD_PRUNING_NUM_HEADS], dim_reduction=dim_reduction)
        elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
            return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
@@ -270,10 +236,9 @@ def convert_conv1d_to_linear(model, convert_type):
    for name, module in c_model.named_modules():
        if isinstance(module, convert_type):
            old_module = recursive_getattr(c_model, name)
            new_module = torch.nn.Linear(old_module.weight.data.size(0),
                                         old_module.weight.data.size(1),
                                         bias=True if old_module.bias is not None else False)
            new_module.weight.data = old_module.weight.data.t().contiguous()
            if new_module.bias is not None:
                new_module.bias.data = old_module.bias.data.view(-1)
...
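The transpose in convert_conv1d_to_linear is what you would expect when convert_type is a GPT-2 style Conv1D layer (for example transformers.pytorch_utils.Conv1D, assumed here and not referenced anywhere in this diff), which stores its weight as (in_features, out_features), i.e. the transpose of torch.nn.Linear's (out_features, in_features) layout. A hypothetical call would look like:

from transformers.pytorch_utils import Conv1D  # assumed convert_type, not part of this merge request

model = convert_conv1d_to_linear(model, Conv1D)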
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .compress import get_module_name
from .constants import *
@@ -10,6 +13,7 @@ class compression_scheduler():
    '''
    Used to schedule different compression methods
    '''

    def __init__(self, model, compression_config):
        self.model = model
        self.compression_config = compression_config
@@ -38,22 +42,22 @@ class compression_scheduler():
            }
            exist_module_name = set()
            shared_parameters = method_content[SHARED_PARAMETERS]
            self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
            self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters
            for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
                module_name_list = []
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     self.model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     verbose=False)
                    module_name_list.extend(module_name)
                if module_name_list:
                    self.different_compression_methods[method][DIFFERENT_GROUPS].append(
                        [group_name, module_name_list,
                         method_parameters.copy().pop('params')])
    def check_weight_quantization(self):
        # check weight quantization
@@ -69,8 +73,7 @@ class compression_scheduler():
                        module.weight_quantization_enabled = True

                if not self.verbose[WEIGHT_QUANTIZATION]:
                    logger.info(f'Weight quantization is enabled at step {self.training_steps}')
                    self.weight_quantization_enabled = True
                    self.verbose[WEIGHT_QUANTIZATION] = True
@@ -87,9 +90,7 @@ class compression_scheduler():
                        module = recursive_getattr(self.model, module_name)
                        module.activation_quantization_enabled = True

                if not self.verbose[ACTIVATION_QUANTIZATION]:
                    logger.info(f'Activation quantization is enabled at step {self.training_steps}')
                    self.verbose[ACTIVATION_QUANTIZATION] = True
    def check_sparse_pruning(self):
@@ -105,8 +106,7 @@ class compression_scheduler():
                        module = recursive_getattr(self.model, module_name)
                        module.sparse_pruning_enabled = True

                if not self.verbose[SPARSE_PRUNING]:
                    logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
                    self.verbose[SPARSE_PRUNING] = True
    def check_head_pruning(self):
@@ -154,8 +154,7 @@ class compression_scheduler():
                        module = recursive_getattr(self.model, module_name)
                        module.channel_pruning_enabled = True

                if not self.verbose[CHANNEL_PRUNING]:
                    logger.info(f'Channel pruning is enabled at step {self.training_steps}')
                    self.verbose[CHANNEL_PRUNING] = True
    def check_all_modules(self):
...
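The check_* methods above only flip the per-module enable flags once self.training_steps passes each technique's schedule offset; how training_steps advances is not shown in these hunks. A hypothetical driver loop, using only names visible in this diff plus assumed training-loop objects (dataloader, optimizer), might look roughly like:

scheduler = compression_scheduler(model, compression_config)
for step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.training_steps = step    # assumed bookkeeping; the real update hook is not in this diff
    scheduler.check_all_modules()      # enables any technique whose schedule offset has been reached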
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from torch import autograd
@@ -13,6 +16,7 @@ class TopKBinarizer(autograd.Function):
    Implementation is inspired from:
        https://github.com/yaozhewei/MLPruning
    """

    @staticmethod
    def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
        """
@@ -59,6 +63,7 @@ class SymQuantizer(torch.autograd.Function):
    """
    Symmetric quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
@@ -75,9 +80,8 @@ class SymQuantizer(torch.autograd.Function):
            quantized_input (`torch.FloatTensor`)
                Quantized input
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                             and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
@@ -101,6 +105,7 @@ class AsymQuantizer(torch.autograd.Function):
    """
    Asymmetric quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
@@ -118,9 +123,8 @@ class AsymQuantizer(torch.autograd.Function):
                Quantized input
        """
        assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
                                                             and num_groups == 1)
        q_range = 2**num_bits
        input_shape = input.shape
        if min_value is None:
@@ -131,9 +135,7 @@ class AsymQuantizer(torch.autograd.Function):
        scale = (max_value - min_value) / q_range
        zero_point = (min_value / scale).round() * scale
        output = ((input - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
        output = output.reshape(input_shape).contiguous()
        return output
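As a quick sanity check of the asymmetric fake-quantization formula above, here is a standalone PyTorch sketch (independent of this class) that applies the same scale/zero-point arithmetic to a small tensor:

import torch

x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
num_bits = 8
q_range = 2**num_bits                              # 256 representable levels
min_value, max_value = x.min(), x.max()            # -1.0 and 1.0
scale = (max_value - min_value) / q_range          # 2 / 256 = 0.0078125
zero_point = (min_value / scale).round() * scale   # -1.0
q = ((x - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
print(q)  # dequantized values snapped to the 256-level grid anchored at min_value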
@@ -147,6 +149,7 @@ class TernaryQuantizer(torch.autograd.Function):
    """
    Ternary quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
@@ -187,6 +190,7 @@ class BinaryQuantizer(torch.autograd.Function):
    """
    Binary quantization
    """

    @staticmethod
    def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
        """
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from datetime import timedelta

#############################################
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
from .utils import is_torch_elastic_compatible
...
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import json import json
from .constants import * from .constants import *
...@@ -43,77 +44,64 @@ class ElasticityConfig: ...@@ -43,77 +44,64 @@ class ElasticityConfig:
"version": 0.1 "version": 0.1
} }
""" """
def __init__(self, param_dict): def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT) self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled: if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict: if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE] self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else: else:
raise ElasticityConfigError( raise ElasticityConfigError(f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict: if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES] self.micro_batches = param_dict[MICRO_BATCHES]
else: else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}") raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else: else:
self.max_acceptable_batch_size = param_dict.get( self.max_acceptable_batch_size = param_dict.get(MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE, MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT) self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
if not isinstance(self.micro_batches, list): if not isinstance(self.micro_batches, list):
raise ElasticityConfigError( raise ElasticityConfigError(
f"Elasticity expected value of {MICRO_BATCHES} to be a " f"Elasticity expected value of {MICRO_BATCHES} to be a "
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}" f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}")
)
if not all(map(lambda m: isinstance(m, int), self.micro_batches)): if not all(map(lambda m: isinstance(m, int), self.micro_batches)):
raise ElasticityConfigError( raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, " f"instead contains: f{self.micro_batches}")
f"instead contains: f{self.micro_batches}")
if not all(map(lambda m: m > 0, self.micro_batches)): if not all(map(lambda m: m > 0, self.micro_batches)):
raise ElasticityConfigError( raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, " f"instead contains: f{self.micro_batches}")
f"instead contains: f{self.micro_batches}")
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT) self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT) self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
if self.min_gpus < 1 or self.max_gpus < 1: if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError( raise ElasticityConfigError("Elasticity min/max gpus must be > 0, "
"Elasticity min/max gpus must be > 0, " f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
if self.max_gpus < self.min_gpus: if self.max_gpus < self.min_gpus:
raise ElasticityConfigError( raise ElasticityConfigError("Elasticity min_gpus cannot be greater than max_gpus, "
"Elasticity min_gpus cannot be greater than max_gpus, " f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE, self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT)
MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1: if self.model_parallel_size < 1:
raise ElasticityConfigError( raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
"Model-Parallel size cannot be less than 1, " f"given model-parallel size: {self.model_parallel_size}")
f"given model-parallel size: {self.model_parallel_size}")
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE, self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1: if self.num_gpus_per_node < 1:
raise ElasticityConfigError( raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
"Number of GPUs per node cannot be less than 1, " f"given number of GPUs per node: {self.num_gpus_per_node}")
f"given number of GPUs per node: {self.num_gpus_per_node}")
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT) self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0: if self.min_time < 0:
raise ElasticityConfigError( raise ElasticityConfigError(f"Elasticity min time needs to be >= 0: given {self.min_time}")
f"Elasticity min time needs to be >= 0: given {self.min_time}")
self.version = param_dict.get(VERSION, VERSION_DEFAULT) self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, PREFER_LARGER_BATCH_DEFAULT)
PREFER_LARGER_BATCH_DEFAULT) self.ignore_non_elastic_batch_info = param_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
self.ignore_non_elastic_batch_info = param_dict.get( IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
def repr(self): def repr(self):
return self.__dict__ return self.__dict__
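A hedged usage sketch of the validation above, assuming ElasticityConfig and ElasticityConfigError are in scope; the JSON key strings are assumptions based on the constants referenced in this class (e.g. MAX_ACCEPTABLE_BATCH_SIZE -> "max_train_batch_size", MICRO_BATCHES -> "micro_batch_sizes"):

# Hypothetical dicts; key names mirror the elasticity JSON schema in the class docstring.
good = {"enabled": True, "max_train_batch_size": 2000, "micro_batch_sizes": [2, 4, 6]}
bad = {"enabled": True, "max_train_batch_size": 2000, "micro_batch_sizes": [2, -4, 6]}

print(ElasticityConfig(good).repr())    # parses and exposes the parsed fields
try:
    ElasticityConfig(bad)               # negative micro batch size is rejected
except ElasticityConfigError as err:
    print(err)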
......
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
######################################### #########################################
# Elasticity # Elasticity
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
...@@ -23,6 +26,7 @@ import subprocess ...@@ -23,6 +26,7 @@ import subprocess
class DSElasticAgent(LocalElasticAgent): class DSElasticAgent(LocalElasticAgent):
def __init__( def __init__(
self, self,
spec: WorkerSpec, spec: WorkerSpec,
...@@ -35,9 +39,7 @@ class DSElasticAgent(LocalElasticAgent): ...@@ -35,9 +39,7 @@ class DSElasticAgent(LocalElasticAgent):
self.ds_env = env self.ds_env = env
@staticmethod @staticmethod
def _set_master_addr_port(store: Store, def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]):
master_addr: Optional[str],
master_port: Optional[int]):
if master_port is None: if master_port is None:
sock = _get_socket_with_port() sock = _get_socket_with_port()
with closing(sock): with closing(sock):
...@@ -82,8 +84,7 @@ class DSElasticAgent(LocalElasticAgent): ...@@ -82,8 +84,7 @@ class DSElasticAgent(LocalElasticAgent):
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
str(1)),
} }
worker_env_ds.update(worker_env_elastic) worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ: if "OMP_NUM_THREADS" in os.environ:
...@@ -120,8 +121,7 @@ class DSElasticAgent(LocalElasticAgent): ...@@ -120,8 +121,7 @@ class DSElasticAgent(LocalElasticAgent):
spec = self._worker_group.spec spec = self._worker_group.spec
role = spec.role role = spec.role
log.info( log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
self._initialize_workers(self._worker_group) self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval monitor_interval = spec.monitor_interval
...@@ -136,13 +136,10 @@ class DSElasticAgent(LocalElasticAgent): ...@@ -136,13 +136,10 @@ class DSElasticAgent(LocalElasticAgent):
state = run_result.state state = run_result.state
self._worker_group.state = state self._worker_group.state = state
expire_time = datetime.utcnow() - ( expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_interval * rdzv_handler._settings.keep_alive_max_attempt)
rdzv_handler._settings.keep_alive_max_attempt)
_dead_nodes = [ _dead_nodes = [
node for node, node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
last_heartbeat in
rdzv_handler._state_holder.state.last_heartbeats.items()
if last_heartbeat < expire_time if last_heartbeat < expire_time
] ]
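The dead-node scan above is a plain timestamp comparison: a participant is treated as gone once its last heartbeat is older than keep_alive_interval * keep_alive_max_attempt. A standalone sketch with made-up values:

from datetime import datetime, timedelta

keep_alive_interval = timedelta(seconds=5)
keep_alive_max_attempt = 3
last_heartbeats = {
    "node-0": datetime.utcnow(),                          # fresh
    "node-1": datetime.utcnow() - timedelta(seconds=60),  # stale
}

expire_time = datetime.utcnow() - keep_alive_interval * keep_alive_max_attempt
dead_nodes = [node for node, hb in last_heartbeats.items() if hb < expire_time]
print(dead_nodes)  # expected: ['node-1']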
...@@ -150,21 +147,16 @@ class DSElasticAgent(LocalElasticAgent): ...@@ -150,21 +147,16 @@ class DSElasticAgent(LocalElasticAgent):
put_metric(f"workers.{role}.{state.name.lower()}", 1) put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED: if state == WorkerState.SUCCEEDED:
log.info( log.info(f"[{role}] worker group successfully finished."
f"[{role}] worker group successfully finished." f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier() self._exit_barrier()
return run_result return run_result
elif state in { elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
WorkerState.UNHEALTHY, } or len(participants) > len(rdzv_handler._state_holder.state.participants):
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0: if self._remaining_restarts > 0:
log.info( log.info(f"[{role}] Worker group {state.name}. "
f"[{role}] Worker group {state.name}. " f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;" f" will restart worker group")
f" will restart worker group")
self._remaining_restarts -= 1 self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False # rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group) self._restart_workers(self._worker_group)
......
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import os import os
import json import json
import numpy as np import numpy as np
...@@ -17,44 +19,8 @@ from ..utils import logger ...@@ -17,44 +19,8 @@ from ..utils import logger
# Thirty eight smallest highly composite numbers. The list should # Thirty eight smallest highly composite numbers. The list should
# be enough to support up to 720K batch size. # be enough to support up to 720K batch size.
HCN_LIST = [ HCN_LIST = [
1, 1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 7560, 10080, 15120, 20160,
2, 25200, 27720, 45360, 50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, 720720
4,
6,
12,
24,
36,
48,
60,
120,
180,
240,
360,
720,
840,
1260,
1680,
2520,
5040,
7560,
10080,
15120,
20160,
25200,
27720,
45360,
50400,
55440,
83160,
110880,
166320,
221760,
277200,
332640,
498960,
554400,
665280,
720720
] ]
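Highly composite numbers are a natural pool of candidate batch sizes because elasticity wants each candidate to be divisible by as many GPU counts as possible. A quick illustration of why (illustrative helper, not part of the module):

def divisors(n: int):
    # Every divisor of a candidate batch size is a potential data-parallel degree.
    return [d for d in range(1, n + 1) if n % d == 0]

# 720 is on the HCN list; 719 (a prime) is not. The divisor counts explain the choice.
print(len(divisors(720)), len(divisors(719)))  # 30 vs 2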
...@@ -94,11 +60,7 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus): ...@@ -94,11 +60,7 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
return valid_gpus return valid_gpus
def get_best_candidates(candidate_batch_sizes, def get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus, prefer_larger):
micro_batches,
min_gpus,
max_gpus,
prefer_larger):
max_valid_gpus = 0 max_valid_gpus = 0
valid_gpus = None valid_gpus = None
...@@ -106,15 +68,11 @@ def get_best_candidates(candidate_batch_sizes, ...@@ -106,15 +68,11 @@ def get_best_candidates(candidate_batch_sizes,
for batch_size in candidate_batch_sizes: for batch_size in candidate_batch_sizes:
current_valid_gpus = get_valid_gpus(batch_size, current_valid_gpus = get_valid_gpus(batch_size, micro_batches, min_gpus, max_gpus)
micro_batches,
min_gpus,
max_gpus)
if (len(current_valid_gpus) > max_valid_gpus if (len(current_valid_gpus) > max_valid_gpus or (len(current_valid_gpus) == max_valid_gpus and
or (len(current_valid_gpus) == max_valid_gpus and ((prefer_larger and batch_size > final_batch_size) or
((prefer_larger and batch_size > final_batch_size) or (not prefer_larger and batch_size < final_batch_size)))):
(not prefer_larger and batch_size < final_batch_size)))):
max_valid_gpus = len(current_valid_gpus) max_valid_gpus = len(current_valid_gpus)
valid_gpus = current_valid_gpus valid_gpus = current_valid_gpus
final_batch_size = batch_size final_batch_size = batch_size
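The loop above keeps the candidate with the largest set of valid GPU counts and, on a tie, lets prefer_larger decide the direction. A stripped-down sketch of just that selection rule, with the valid-GPU computation stubbed out as a callable (hypothetical helper, not the DeepSpeed implementation):

def pick_batch_size(candidates, valid_gpus_of, prefer_larger=True):
    best_size, best_gpus = 0, []
    for bs in candidates:
        gpus = valid_gpus_of(bs)
        tie_break = (prefer_larger and bs > best_size) or (not prefer_larger and bs < best_size)
        if len(gpus) > len(best_gpus) or (len(gpus) == len(best_gpus) and tie_break):
            best_size, best_gpus = bs, gpus
    return best_size, best_gpus

# 360 and 720 both admit seven GPU counts in 1..8, so prefer_larger picks 720.
print(pick_batch_size([360, 720], lambda b: [d for d in range(1, 9) if b % d == 0]))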
...@@ -157,15 +115,10 @@ def _get_compatible_gpus_v01(micro_batches, ...@@ -157,15 +115,10 @@ def _get_compatible_gpus_v01(micro_batches,
base_list.extend(micro_batches) base_list.extend(micro_batches)
base_list.append(lcm) base_list.append(lcm)
candidate_batch_sizes = get_candidate_batch_sizes(base_list, candidate_batch_sizes = get_candidate_batch_sizes(base_list, max_acceptable_batch_size)
max_acceptable_batch_size)
final_batch_size, valid_gpus = get_best_candidates( final_batch_size, valid_gpus = get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus,
candidate_batch_sizes, prefer_larger)
micro_batches,
min_gpus,
max_gpus,
prefer_larger)
return final_batch_size, valid_gpus return final_batch_size, valid_gpus
...@@ -203,11 +156,12 @@ def _get_compatible_gpus_v02(micro_batches, ...@@ -203,11 +156,12 @@ def _get_compatible_gpus_v02(micro_batches,
dp_size_per_node = num_gpus_per_node // model_parallel_size dp_size_per_node = num_gpus_per_node // model_parallel_size
final_batch_size, valid_world_size = _get_compatible_gpus_v01(micro_batches, final_batch_size, valid_world_size = _get_compatible_gpus_v01(
int(max_acceptable_batch_size/dp_size_per_node), micro_batches,
int(min_gpus/num_gpus_per_node), int(max_acceptable_batch_size / dp_size_per_node),
int(max_gpus/num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level int(min_gpus / num_gpus_per_node),
prefer_larger=prefer_larger) int(max_gpus / num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size = int(final_batch_size) * dp_size_per_node final_batch_size = int(final_batch_size) * dp_size_per_node
valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size] valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
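The v0.2 path simply reruns the v0.1 search at node granularity: the batch and GPU bounds are shrunk by the per-node data-parallel size going in, and the answer is scaled back up coming out. Rough arithmetic with made-up numbers (8 GPUs per node, model-parallel size 2):

num_gpus_per_node, model_parallel_size = 8, 2
dp_size_per_node = num_gpus_per_node // model_parallel_size            # 4

max_acceptable_batch_size = 2000
node_level_budget = int(max_acceptable_batch_size / dp_size_per_node)  # 500 handed to the v0.1 search

# Suppose the node-level search returns batch 480, valid on 1, 2 or 4 nodes:
node_batch, valid_nodes = 480, [1, 2, 4]
final_batch_size = int(node_batch) * dp_size_per_node                  # 1920
valid_dp_world_size = [n * dp_size_per_node for n in valid_nodes]      # [4, 8, 16]
print(final_batch_size, valid_dp_world_size)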
...@@ -256,38 +210,27 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict): ...@@ -256,38 +210,27 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
Ensure the resource scheduler saw the same elastic config we are using at runtime Ensure the resource scheduler saw the same elastic config we are using at runtime
""" """
if DEEPSPEED_ELASTICITY_CONFIG in os.environ: if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
scheduler_elastic_config_dict = json.loads( scheduler_elastic_config_dict = json.loads(os.environ[DEEPSPEED_ELASTICITY_CONFIG])
os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict) scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict) runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}" err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size: if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
raise ElasticityConfigError( raise ElasticityConfigError(
err_str.format('max_acceptable_batch_size', err_str.format('max_acceptable_batch_size', scheduler_elastic_config.max_acceptable_batch_size,
scheduler_elastic_config.max_acceptable_batch_size, 'max_acceptable_batch_size', runtime_elastic_config.max_acceptable_batch_size))
'max_acceptable_batch_size',
runtime_elastic_config.max_acceptable_batch_size))
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches: if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
raise ElasticityConfigError( raise ElasticityConfigError(
err_str.format('micro_batches', err_str.format('micro_batches', scheduler_elastic_config.micro_batches, 'micro_batches',
scheduler_elastic_config.micro_batches,
'micro_batches',
runtime_elastic_config.micro_batches)) runtime_elastic_config.micro_batches))
if runtime_elastic_config.version != scheduler_elastic_config.version: if runtime_elastic_config.version != scheduler_elastic_config.version:
raise ElasticityConfigError( raise ElasticityConfigError(
err_str.format('version', err_str.format('version', scheduler_elastic_config.version, 'version', runtime_elastic_config.version))
scheduler_elastic_config.version,
'version',
runtime_elastic_config.version))
else: else:
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \ logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
"guarantee resource scheduler will scale this job using compatible GPU counts.") "guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict, def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0, return_microbatch=False):
target_deepspeed_version: str,
world_size=0,
return_microbatch=False):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below) """Core deepspeed elasticity API. Given an elastic config (similar to the example below)
DeepSpeed will compute a total train batch size and corresponding valid GPU count list that DeepSpeed will compute a total train batch size and corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale provides a high level of elasticity. Elasticity in this case means we are safe to scale
...@@ -397,8 +340,7 @@ def compute_elastic_config(ds_config: dict, ...@@ -397,8 +340,7 @@ def compute_elastic_config(ds_config: dict,
# ensure batch size is int dtype # ensure batch size is int dtype
final_batch_size = int(final_batch_size) final_batch_size = int(final_batch_size)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Unable to find elastic logic for version: {elastic_config.version}")
f"Unable to find elastic logic for version: {elastic_config.version}")
logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}") logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")
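A hedged usage sketch of this API; the ds_config layout mirrors the elasticity block referenced above, and the concrete values are placeholders rather than recommendations:

import deepspeed

ds_config = {
    "elasticity": {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "version": 0.1,
    }
}
# world_size=0 asks only for the batch size and the list of compatible GPU counts.
final_batch, valid_gpus = deepspeed.elasticity.compute_elastic_config(
    ds_config=ds_config, target_deepspeed_version="0.9.2", world_size=0)
print(final_batch, valid_gpus)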
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
import deepspeed import deepspeed
...@@ -48,8 +51,7 @@ def op_report(verbose=True): ...@@ -48,8 +51,7 @@ def op_report(verbose=True):
dots = "." * (max_dots - len(op_name)) dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible(verbose) else no is_compatible = OKAY if builder.is_compatible(verbose) else no
is_installed = installed if installed_ops[op_name] else no is_installed = installed if installed_ops[op_name] else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
(len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible) print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1]))) print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
...@@ -68,9 +70,7 @@ def nvcc_version(): ...@@ -68,9 +70,7 @@ def nvcc_version():
if cuda_home is None: if cuda_home is None:
return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}" return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}"
try: try:
output = subprocess.check_output([cuda_home + "/bin/nvcc", output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
"-V"],
universal_newlines=True)
except FileNotFoundError: except FileNotFoundError:
return f"{RED} [FAIL] nvcc missing {END}" return f"{RED} [FAIL] nvcc missing {END}"
output_split = output.split() output_split = output.split()
...@@ -82,32 +82,18 @@ def nvcc_version(): ...@@ -82,32 +82,18 @@ def nvcc_version():
def debug_report(): def debug_report():
max_dots = 33 max_dots = 33
report = [ report = [("torch install path", torch.__path__), ("torch version", torch.__version__),
("torch install path", ("deepspeed install path", deepspeed.__path__),
torch.__path__), ("deepspeed info", f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}")]
("torch version",
torch.__version__),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)
]
if get_accelerator().device_name() == 'cuda': if get_accelerator().device_name() == 'cuda':
hip_version = getattr(torch.version, "hip", None) hip_version = getattr(torch.version, "hip", None)
report.extend([("torch cuda version", report.extend([("torch cuda version", torch.version.cuda), ("torch hip version", hip_version),
torch.version.cuda), ("nvcc version", (None if hip_version else nvcc_version())),
("torch hip version", ("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
hip_version), (f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
("nvcc version", ])
(None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}"))])
else: else:
report.extend([("deepspeed wheel compiled w.", report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
f"torch {torch_info['version']} ")])
print("DeepSpeed general environment info:") print("DeepSpeed general environment info:")
for name, value in report: for name, value in report:
...@@ -116,15 +102,10 @@ def debug_report(): ...@@ -116,15 +102,10 @@ def debug_report():
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('--hide_operator_status',
'--hide_operator_status',
action='store_true',
help=
'Suppress display of installation and compatibility statuses of DeepSpeed operators. '
)
parser.add_argument('--hide_errors_and_warnings',
action='store_true', action='store_true',
help='Suppress warning and error messages.') help='Suppress display of installation and compatibility statuses of DeepSpeed operators. ')
parser.add_argument('--hide_errors_and_warnings', action='store_true', help='Suppress warning and error messages.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -137,8 +118,7 @@ def main(hide_operator_status=False, hide_errors_and_warnings=False): ...@@ -137,8 +118,7 @@ def main(hide_operator_status=False, hide_errors_and_warnings=False):
def cli_main(): def cli_main():
args = parse_arguments() args = parse_arguments()
main(hide_operator_status=args.hide_operator_status, main(hide_operator_status=args.hide_operator_status, hide_errors_and_warnings=args.hide_errors_and_warnings)
hide_errors_and_warnings=args.hide_errors_and_warnings)
if __name__ == "__main__": if __name__ == "__main__":
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
try: try:
# This is populated by setup.py # This is populated by setup.py
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .engine import InferenceEngine from .engine import InferenceEngine
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
from deepspeed.runtime.config_utils import DeepSpeedConfigModel from deepspeed.runtime.config_utils import DeepSpeedConfigModel
...@@ -12,10 +15,12 @@ from enum import Enum ...@@ -12,10 +15,12 @@ from enum import Enum
class DtypeEnum(Enum): class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype) # The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8" int8 = torch.int8, "torch.int8", "int8"
# bf16 not supported
# bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
# Copied from https://stackoverflow.com/a/43210118 # Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first # Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called # listed value when Enum is called
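Assuming the lookup override from the linked StackOverflow answer is present in the elided lines and this runs in the module's namespace, the multi-value members behave like this (a usage sketch, not a guaranteed contract):

# Any alias resolves to the same member...
assert DtypeEnum("fp16") is DtypeEnum("half") is DtypeEnum(torch.float16)
# ...and the torch dtype, listed first, is the canonical value of that member.
print(DtypeEnum.fp16)        # DtypeEnum.fp16
print(DtypeEnum.fp16.value)  # expected: torch.float16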
...@@ -192,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): ...@@ -192,6 +197,11 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
This can be passed through the json config too. This can be passed through the json config too.
""" """
set_empty_params: bool = False
"""
Specifies whether the inference module is created with empty or real tensors.
"""
save_mp_checkpoint_path: str = None save_mp_checkpoint_path: str = None
""" """
The path for which we want to save the loaded model with a checkpoint. This The path for which we want to save the loaded model with a checkpoint. This
...@@ -222,9 +232,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): ...@@ -222,9 +232,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
replace_method: str = Field( replace_method: str = Field(
"auto", "auto",
deprecated=True, deprecated=True,
deprecated_msg= deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
"This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
)
injection_policy: Dict = Field(None, alias="injection_dict") injection_policy: Dict = Field(None, alias="injection_dict")
""" """
...@@ -235,9 +243,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): ...@@ -235,9 +243,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
injection_policy_tuple: tuple = None injection_policy_tuple: tuple = None
""" TODO: Add docs """ """ TODO: Add docs """
config: Dict = Field( config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
None,
alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens") max_out_tokens: int = Field(1024, alias="max_tokens")
""" """
...@@ -246,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): ...@@ -246,6 +252,16 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
to the required token-length required for your use-case. to the required token-length required for your use-case.
""" """
min_out_tokens: int = Field(1, alias="min_tokens")
"""
This argument communicates to the runtime the minimum number of tokens you
expect you will need to generate. This will cause the runtime to raise an error
if it is unable to provide this, and to report context on the memory pressure
rather than seg-faulting or producing corrupted output.
"""
transposed_mode: bool = Field(False, alias="transposed_mode")
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size") mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
""" """
Desired model parallel size, default is 1 meaning no model parallelism. Desired model parallel size, default is 1 meaning no model parallelism.
...@@ -254,18 +270,10 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): ...@@ -254,18 +270,10 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
""" """
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu") mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size") ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
ep_group: object = Field(None, ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
alias="expert_group", ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
deprecated=True,
new_param="moe.ep_group")
ep_mp_group: object = Field(None,
alias="expert_mp_group",
deprecated=True,
new_param="moe.ep_mp_group")
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts") moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
deprecated=True,
new_param="moe.type")
@validator("moe") @validator("moe")
def moe_backward_compat(cls, field_value, values): def moe_backward_compat(cls, field_value, values):
......
''' # Copyright (c) Microsoft Corporation.
Copyright 2021 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import torch import torch
import time import time
import os import os
...@@ -32,6 +34,58 @@ from torch import nn ...@@ -32,6 +34,58 @@ from torch import nn
INFERENCE_MODEL_TIMER = "model-forward-inference" INFERENCE_MODEL_TIMER = "model-forward-inference"
def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as the original paper suggests; it
relies on a translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
attention_mask (`torch.Tensor`):
Token-wise attention mask; this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
Returns:
A tensor shaped (batch_size * num_heads, 1, max_seq_len)
"""
import math
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2**math.floor(math.log2(num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
else:
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
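As a sanity check on the shape contract (single process, so the dist branch above is skipped; assumes this runs where build_bloom_alibi_tensor and torch are in scope):

attention_mask = torch.ones(2, 5)  # batch_size=2, seq_len=5
alibi = build_bloom_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float16)
print(alibi.shape)  # expected: torch.Size([8, 1, 5]) == (batch_size * num_heads, 1, seq_len)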
class InferenceEngine(Module): class InferenceEngine(Module):
inference_mp_group = None inference_mp_group = None
inference_ep_group = None inference_ep_group = None
...@@ -80,13 +134,18 @@ class InferenceEngine(Module): ...@@ -80,13 +134,18 @@ class InferenceEngine(Module):
self.checkpoint_engine = TorchCheckpointEngine() self.checkpoint_engine = TorchCheckpointEngine()
quantization_setting = None quantization_setting = None
self._init_quantization_setting( self._init_quantization_setting(
quantization_setting quantization_setting) # todo: update with the new quant config for weight quant
) # todo: update with the new quant config for weight quant
self.model_profile_enabled = False self.model_profile_enabled = False
self._model_times = [] self._model_times = []
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture if not self.injection_dict and config.replace_with_kernel_inject:
self.remove_mask_prepare_for_bloom() # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
self.remove_mask_prepare_for_bloom()
if self.injection_dict or not config.replace_with_kernel_inject:
# This is a hack to redefine the alibi func due to TP
if config.tensor_parallel.tp_size > 1:
self.build_alibi_tensor()
if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph: if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
...@@ -100,8 +159,7 @@ class InferenceEngine(Module): ...@@ -100,8 +159,7 @@ class InferenceEngine(Module):
self._convert_to_dtype(config) self._convert_to_dtype(config)
if self.mpu: if self.mpu:
config.tensor_parallel.tp_size = dist.get_world_size( config.tensor_parallel.tp_size = dist.get_world_size(group=self.mpu.get_model_parallel_group())
group=self.mpu.get_model_parallel_group())
self.mp_group = self.mpu.get_model_parallel_group() self.mp_group = self.mpu.get_model_parallel_group()
elif config.tensor_parallel.tp_size > 1: elif config.tensor_parallel.tp_size > 1:
self._create_model_parallel_group(config) self._create_model_parallel_group(config)
...@@ -149,8 +207,7 @@ class InferenceEngine(Module): ...@@ -149,8 +207,7 @@ class InferenceEngine(Module):
self.module.to(device) self.module.to(device)
if config.tensor_parallel.tp_size > 1: if config.tensor_parallel.tp_size > 1:
_rng_state = get_accelerator().get_rng_state().to( _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0) dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu()) get_accelerator().set_rng_state(_rng_state.cpu())
...@@ -172,15 +229,18 @@ class InferenceEngine(Module): ...@@ -172,15 +229,18 @@ class InferenceEngine(Module):
# todo: remove this once all the config dicts are centralized from top level pydantic config # todo: remove this once all the config dicts are centralized from top level pydantic config
def _get_model_config_generate(self, config): def _get_model_config_generate(self, config):
# this is being passed to replace_transformer_layer(config=self.user_model_config_dict) # this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
self.config = getattr(self.module, self.config = getattr(self.module, 'config', None) if config.config is None else config.config
'config',
None) if config.config is None else config.config
def remove_mask_prepare_for_bloom(self): def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'): if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, '_prepare_attn_mask'): if hasattr(self.module.transformer, '_prepare_attn_mask'):
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
def build_alibi_tensor(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, 'build_alibi_tensor'):
self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
def _pre_forward_hook(self, module, *inputs, **kwargs): def _pre_forward_hook(self, module, *inputs, **kwargs):
if self.use_cuda_events: if self.use_cuda_events:
self.timers(INFERENCE_MODEL_TIMER).start() self.timers(INFERENCE_MODEL_TIMER).start()
...@@ -223,8 +283,7 @@ class InferenceEngine(Module): ...@@ -223,8 +283,7 @@ class InferenceEngine(Module):
num_ep_groups = dist.get_world_size() // moe_ep_size num_ep_groups = dist.get_world_size() // moe_ep_size
for i in range(num_ep_groups): for i in range(num_ep_groups):
ep_cnt = i * moe_ep_size ep_cnt = i * moe_ep_size
size = dist.get_world_size( size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
) if moe_ep_size > dist.get_world_size() else moe_ep_size
ranks = list(range(ep_cnt, ep_cnt + size)) ranks = list(range(ep_cnt, ep_cnt + size))
_ep_group = dist.new_group(ranks) _ep_group = dist.new_group(ranks)
if dist.get_rank() in ranks: if dist.get_rank() in ranks:
...@@ -234,9 +293,7 @@ class InferenceEngine(Module): ...@@ -234,9 +293,7 @@ class InferenceEngine(Module):
num_expert_mp_groups = dist.get_world_size() // num_ep_groups num_expert_mp_groups = dist.get_world_size() // num_ep_groups
expert_mp_size = dist.get_world_size() // moe_ep_size expert_mp_size = dist.get_world_size() // moe_ep_size
for i in range(num_expert_mp_groups): for i in range(num_expert_mp_groups):
expert_mp_comm_ranks = [ expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
i + nr * moe_ep_size for nr in range(expert_mp_size)
]
_expert_mp_group = dist.new_group(expert_mp_comm_ranks) _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
if dist.get_rank() in expert_mp_comm_ranks: if dist.get_rank() in expert_mp_comm_ranks:
self.expert_mp_group.update({moe_ep_size: _expert_mp_group}) self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
...@@ -253,65 +310,48 @@ class InferenceEngine(Module): ...@@ -253,65 +310,48 @@ class InferenceEngine(Module):
log_dist( log_dist(
f"quantize_bits = {self.quantize_bits} " f"quantize_bits = {self.quantize_bits} "
f"mlp_extra_grouping = {self.mlp_extra_grouping}, " f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", f"quantize_groups = {self.quantize_groups}", [0])
[0])
# TODO: remove this function and add this functionality to pydantic config checking # TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject): def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support SD pipeline we need to avoid this check for now # TODO: to support SD pipeline we need to avoid this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module): if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}") raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self._config.tensor_parallel.tp_size, if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
int) or self._config.tensor_parallel.tp_size < 1: raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
raise ValueError(
f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}"
)
if mpu: if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"] methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods: for method in methods:
if not hasattr(mpu, method): if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}") raise ValueError(f"mpu is missing {method}")
if self._config.checkpoint is not None and not isinstance( if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
self._config.checkpoint, raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
(str,
dict)):
raise ValueError(
f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}"
)
supported_dtypes = [None, torch.half, torch.int8, torch.float] supported_dtypes = [None, torch.half, torch.int8, torch.float]
if self._config.dtype not in supported_dtypes: if self._config.dtype not in supported_dtypes:
raise ValueError( raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
if self.injection_dict is not None and not isinstance(self.injection_dict, dict): if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError( raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
f"injection_dict must be None or a dict, got: {self.injection_dict}")
def load_model_with_checkpoint(self, r_module): def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing( self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
error_msgs = [] error_msgs = []
def load(module, state_dict, prefix): def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs) args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'): if hasattr(module, 'weight'):
if 'query_key_value' in prefix: if 'query_key_value' in prefix:
module.weight = self.mp_replace.qkv_copy( module.weight = self.mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight'])
module.weight.data,
state_dict[prefix + 'weight'])
else: else:
module.weight = self.mp_replace.copy(module.weight.data, module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
state_dict[prefix + 'weight'])
else: else:
module.norm.weight = self.mp_replace.copy(module.norm.weight.data, module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list: if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'): if hasattr(module, 'norm'):
module.norm.bias = self.mp_replace.copy(module.norm.bias, module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
state_dict[prefix + 'bias'])
else: else:
data = state_dict[prefix + 'bias'] data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name()) data = data.to(get_accelerator().current_device_name())
...@@ -331,45 +371,32 @@ class InferenceEngine(Module): ...@@ -331,45 +371,32 @@ class InferenceEngine(Module):
checking_key = prefix + name + '.' checking_key = prefix + name + '.'
if not any(checking_key in item for item in self.key_list): if not any(checking_key in item for item in self.key_list):
continue continue
if len(list(child.parameters())) > 0 and list( if len(list(child.parameters())) > 0 and list(child.parameters())[0].numel() == 0:
child.parameters())[0].numel() == 0:
if len(child.weight.ds_shape) == 1: if len(child.weight.ds_shape) == 1:
child = Normalize(dim=child.weight.ds_shape[-1], child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
dtype=child.weight.dtype,
eps=child.eps)
setattr(module, name, child) setattr(module, name, child)
load(child, self.sd, prefix + name + '.') load(child, self.sd, prefix + name + '.')
else: else:
load_module_recursive(child, load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(r_module) load_module_recursive(r_module)
def _apply_injection_policy(self, config, client_module=None): def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method. # client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint checkpoint_dir = config.checkpoint
checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir,
checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None
self.checkpoint_engine) if checkpoint_dir is not None else None
generic_injection(self.module, generic_injection(self.module,
fp16=(config.dtype == torch.half) fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
or (config.dtype == torch.int8),
enable_cuda_graph=config.enable_cuda_graph) enable_cuda_graph=config.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module): if isinstance(self.module, torch.nn.Module):
# config is our DeepSpeedInferenceConfig and self.config is the HF model config # config is our DeepSpeedInferenceConfig and self.config is the HF model config
replace_transformer_layer(client_module, replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
self.module,
checkpoint,
config,
self.config)
def _get_all_ckpt_names(self, checkpoints_path, tag): def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
tag,
mp_placeholder="*")
import glob import glob
ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files = glob.glob(ckpt_file_pattern)
...@@ -392,8 +419,7 @@ class InferenceEngine(Module): ...@@ -392,8 +419,7 @@ class InferenceEngine(Module):
def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
is_pipe_parallel = isinstance(self.module, PipelineModule) is_pipe_parallel = isinstance(self.module, PipelineModule)
if is_pipe_parallel: if is_pipe_parallel:
raise RuntimeError( raise RuntimeError('pipeline parallelism is currently not supported in inference.')
'pipeline parallelism is currently not supported in inference.')
if not isinstance(load_dir, dict) and os.path.isdir(load_dir): if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
if tag is None: if tag is None:
latest_path = os.path.join(load_dir, "latest") latest_path = os.path.join(load_dir, "latest")
...@@ -404,8 +430,7 @@ class InferenceEngine(Module): ...@@ -404,8 +430,7 @@ class InferenceEngine(Module):
ckpt_list = self._get_all_ckpt_names(load_dir, tag) ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else: else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
self.checkpoint_engine)
if type(sd_loader) is list: if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu') self.sd = torch.load(sd_loader[0], map_location='cpu')
...@@ -416,19 +441,18 @@ class InferenceEngine(Module): ...@@ -416,19 +441,18 @@ class InferenceEngine(Module):
for i in range(1, len(sd_loader)): for i in range(1, len(sd_loader)):
if not dist.is_initialized() or dist.get_rank() == 0: if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})") print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i], self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys()) self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module) self.load_model_with_checkpoint(self.module)
else: else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size, load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
mp_rank, mp_rank,
is_pipe_parallel=is_pipe_parallel, is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8), quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups, quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping) mlp_extra_grouping=self.mlp_extra_grouping)
self.quantization_scales, self.quantize_merge_count = quantize_config self.quantization_scales, self.quantize_merge_count = quantize_config
...@@ -438,21 +462,20 @@ class InferenceEngine(Module): ...@@ -438,21 +462,20 @@ class InferenceEngine(Module):
old_moe_load = False old_moe_load = False
if not isinstance(checkpoint['num_experts'], list): if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True old_moe_load = True
DeepSpeedEngine.load_moe_state_dict( DeepSpeedEngine.load_moe_state_dict(load_dir,
load_dir, tag,
tag, state_dict=checkpoint[self._choose_module_key(checkpoint)],
state_dict=checkpoint[self._choose_module_key(checkpoint)], old_moe_load=old_moe_load,
old_moe_load=old_moe_load, model=self.module,
model=self.module, mpu=self.mpu,
mpu=self.mpu, checkpoint_engine=self.checkpoint_engine)
checkpoint_engine=self.checkpoint_engine)
self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)],
self.module.load_state_dict( strict=load_module_strict)
state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
def _choose_module_key(self, sd): def _choose_module_key(self, sd):
assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed" assert not ('module' in sd
and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed" assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' nor 'module' keys, not sure how to proceed"
if 'module' in sd: if 'module' in sd:
return 'module' return 'module'
...@@ -465,10 +488,8 @@ class InferenceEngine(Module): ...@@ -465,10 +488,8 @@ class InferenceEngine(Module):
if False: #config.dtype is torch.int8 and self.quantization_scales is None: if False: #config.dtype is torch.int8 and self.quantization_scales is None:
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping) quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
model, self.quantization_scales = quantizer.model_quantize(self.module, model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict,
self.injection_dict, self.quantize_bits, self.quantize_groups)
self.quantize_bits,
self.quantize_groups)
elif config.dtype == torch.half: elif config.dtype == torch.half:
self.module.half() self.module.half()
elif config.dtype == torch.bfloat16: elif config.dtype == torch.bfloat16:
...@@ -509,11 +530,10 @@ class InferenceEngine(Module): ...@@ -509,11 +530,10 @@ class InferenceEngine(Module):
assert self.model_profile_enabled, "model profiling is not enabled" assert self.model_profile_enabled, "model profiling is not enabled"
model_times = self._model_times model_times = self._model_times
if self._config.enable_cuda_graph and len(self._model_times) == 0: if self._config.enable_cuda_graph and len(self._model_times) == 0:
raise ValueError( raise ValueError("Model times are empty and cuda graph is enabled. If "
"Model times are empty and cuda graph is enabled. If " "this is a GPT-style model this combo is not supported. If this is a "
"this is a GPT-style model this combo is not supported. If this is a " "BERT-style model this is a bug, please report it. "
"BERT-style model this is a bug, please report it. " f"Model type is: {type(self.module)}")
f"Model type is: {type(self.module)}")
self._model_times = [] self._model_times = []
return model_times return model_times
...@@ -532,8 +552,7 @@ class InferenceEngine(Module): ...@@ -532,8 +552,7 @@ class InferenceEngine(Module):
for name in module.__dict__.keys(): for name in module.__dict__.keys():
sub_module = getattr(module, name) sub_module = getattr(module, name)
if self._module_match(sub_module) and hasattr(sub_module, if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"):
"enable_cuda_graph"):
sub_module_cuda_graph = True sub_module_cuda_graph = True
return sub_module_cuda_graph return sub_module_cuda_graph
...@@ -546,13 +565,11 @@ class InferenceEngine(Module): ...@@ -546,13 +565,11 @@ class InferenceEngine(Module):
**kwargs: variable length keyword arguments **kwargs: variable length keyword arguments
""" """
start = None start = None
if self.model_profile_enabled and get_accelerator().device_name( if self.model_profile_enabled and get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
) == 'cuda' and self._config.enable_cuda_graph:
get_accelerator().synchronize() get_accelerator().synchronize()
start = time.time() start = time.time()
if get_accelerator().device_name( if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if self.cuda_graph_created: if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs) outputs = self._graph_replay(*inputs, **kwargs)
else: else:
...@@ -580,9 +597,7 @@ class InferenceEngine(Module): ...@@ -580,9 +597,7 @@ class InferenceEngine(Module):
num_beams = kwargs["num_beams"] num_beams = kwargs["num_beams"]
if num_beams > 1: if num_beams > 1:
raise NotImplementedError( raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"DeepSpeed does not support `num_beams` > 1, if this is important to you please " "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506")
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
)
return self.module.generate(*inputs, **kwargs) return self.module.generate(*inputs, **kwargs)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team''' '''Copyright The Microsoft DeepSpeed Team'''
# Copyright 2020 The Microsoft DeepSpeed Team # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
PDSH_LAUNCHER = 'pdsh' PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024 PDSH_MAX_FAN_OUT = 1024
......
# Copyright 2020 The Microsoft DeepSpeed Team # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
""" """
DeepSpeed launcher. This is similar to torch's distributed.launch but supports DeepSpeed launcher. This is similar to torch's distributed.launch but supports
additional features such as arbitrary gpu exclusion. additional features such as arbitrary gpu exclusion.
...@@ -16,6 +19,7 @@ import base64 ...@@ -16,6 +19,7 @@ import base64
import time import time
import signal import signal
import psutil import psutil
import distutils.spawn  # find_executable() is used by check_for_numactl_pkg() below
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
...@@ -52,10 +56,7 @@ def parse_args():
                        help="Master node (rank 0)'s free port that needs to "
                        "be used for communication during distributed "
                        "training")
    parser.add_argument("--world_info", default="None", type=str, help="world info base64 encoded dictionary")
    parser.add_argument("--module",
                        action="store_true",
...@@ -68,19 +69,11 @@ def parse_args():
                        help="Skip prepending the training script with "
                        "'python' - just execute it directly.")
    parser.add_argument("--enable_elastic_training", action="store_true", help="Enable elastic training support.")
    parser.add_argument("--min_elastic_nodes", type=int, default=-1, help="Min number of nodes in elastic training.")
    parser.add_argument("--max_elastic_nodes", type=int, default=-1, help="Max number of nodes in elastic training.")
    parser.add_argument("--no_local_rank",
                        action="store_true",
...@@ -92,11 +85,22 @@ def parse_args():
                        default=0,
                        help="main launching process pid, for internal pid tracking")
    parser.add_argument("--enable_each_rank_log",
                        default="None",
                        type=str,
                        help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument("--bind_cores_to_rank",
action="store_true",
help="Bind each rank to different cores of the host. "
"This improves host efficiency especially for CPU backend")
parser.add_argument("--bind_core_list",
type=str,
default=None,
help="List of cores to bind to with comma separated list of "
"numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
"specified, all cores on system would be used rank binding")
    # positional
    parser.add_argument("training_script",
...@@ -126,6 +130,89 @@ def terminate_process_tree(pid):
        p.kill()
def parse_range(rng):
try:
value = int(rng)
return range(value, value + 1)
except ValueError:
# value is not a single number
parts = rng.split('-')
if len(parts) != 2:
raise ValueError("Bad range: '%s', range must be either a number or two number separated by dash" %
(rng, ))
start = int(parts[0])
end = int(parts[1])
if start > end:
raise ValueError("Bad range: '%s', range end must larger than or equal to start" % (rng, ))
return range(start, end + 1)
# parse comma and dash separated range list into list
# i.e. "0,2-4,6" --> [0, 2, 3, 4, 6]
# rules:
# 1. Items in the range list are comma separated; each item is either a single number
#    or a range marked by two numbers (both numbers are included in the range)
# 2. Sub ranges must be in ascending order and must not overlap with each other
# 3. No spaces in the range expression
def parse_range_list(range_str):
number_list = []
last = -1
range_list = range_str.split(',')
for sub_range in range_list:
sub_number_list = parse_range(sub_range)
if sub_number_list[0] <= last:
raise ValueError(
"Bad range: '%s', sub ranges must not overlap with each other and should be in ascend order" %
(range_str, ))
last = sub_number_list[-1]
number_list.extend(sub_number_list)
return number_list
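# A quick sanity sketch of the helpers above (values are illustrative only):
#   parse_range_list("0,2-4,6")  ->  [0, 2, 3, 4, 6]
#   parse_range_list("3")        ->  [3]
#   parse_range_list("4,2-3")    ->  ValueError (sub ranges must be ascending and non-overlapping)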
# return a list of list for cores to numa mapping
# [
# [ cores for numa 0 ]
# [ cores belonging to numa 1 ]
# ...
# ]
def get_numa_cores():
ret = []
output = subprocess.check_output(['numactl', '--hardware']).decode("utf-8")
lines = output.split('\n')
for line in lines:
if line.startswith('available:'):
num_numas = int(line.split(' ')[1])
break
for numa in range(num_numas):
for line in lines:
if line.startswith(f'node {numa} cpus:'):
cores = line.split(' ')[3:]
ret.append([int(core) for core in cores])
return ret
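# For reference, get_numa_cores() above expects `numactl --hardware` output of roughly this
# shape (machine-dependent illustration, not captured from a real run):
#
#   available: 2 nodes (0-1)
#   node 0 cpus: 0 1 2 3
#   node 1 cpus: 4 5 6 7
#
# which it parses into [[0, 1, 2, 3], [4, 5, 6, 7]].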
def check_for_numactl_pkg():
libs = dict(
dpkg=["-l", "numactl", "apt"],
pacman=["-Q", "numactl", "pacman"],
rpm=["-q", "numactl", "yum"],
)
found = False
for pkgmgr, data in libs.items():
flag, lib, tool = data
path = distutils.spawn.find_executable(pkgmgr)
if path is not None:
cmd = f"{pkgmgr} {flag} {lib}"
result = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
if result.wait() == 0:
found = True
else:
print(f"please install the {lib} package with {tool}")
break
return found
def main():
    args = parse_args()
    current_env = os.environ.copy()
...@@ -145,9 +232,7 @@ def main():
    local_node = node_list[args.node_rank]
    local_gpu_ids = world_info[local_node]
    num_local_procs = len(local_gpu_ids)
    logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}")
    global_rank_mapping = defaultdict(list)
    curr_global_rank = 0
...@@ -193,8 +278,7 @@ def main():
            lines = file.readlines()
            lines = [line.rstrip() for line in lines]
            for line in lines:
                if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
                    key_val = line.split()[1]
                    key, val = key_val.split('=')
                    current_env[key] = val
...@@ -206,17 +290,13 @@ def main():
    if args.enable_each_rank_log != "None":
        # prepare the log path and the file name prefix
        if os.path.isfile(args.enable_each_rank_log):
            raise ValueError(f"{args.enable_each_rank_log} should not be a file, it should be a directory.")
        if not os.path.exists(args.enable_each_rank_log):
            try:
                os.makedirs(args.enable_each_rank_log)
            except Exception as e:
                print(e)
                raise ValueError(f"unable to create directory {args.enable_each_rank_log} for each rank log.")
        log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())
    for local_rank in range(0, num_local_procs):
...@@ -227,8 +307,43 @@ def main():
        # spawn the processes
        cmd = []
if args.bind_cores_to_rank:
check_for_numactl_pkg()
if 'KMP_AFFINITY' in os.environ.keys():
raise ValueError("Environment variable KMP_AFFINITY conflicts with numactl "
"because it interfere with how many CPU cores numactl can set. "
"Unset KMP_AFFINITY before launching deepspeed.\n\n"
"\t$ unset KMP_AFFINITY\n"
"\t$ deepspeed <deepspeed command parameters>")
            if args.bind_core_list is not None:
core_list = parse_range_list(args.bind_core_list)
total_cores = len(core_list)
else:
total_cores = psutil.cpu_count(logical=False)
core_list = range(total_cores)
cores_per_rank = total_cores // num_local_procs
assert cores_per_rank >= 1, "At least one core needs to be assigned to each rank"
core_list_for_rank = core_list[cores_per_rank * local_rank:cores_per_rank * (local_rank + 1)]
current_env["OMP_NUM_THREADS"] = f"{cores_per_rank}"
cmd.append("numactl")
# check if all cores belong to same numa, if true, bind process to that numa domain with -m parameter
numa_cores = get_numa_cores()
num_numas = len(numa_cores)
for i in range(num_numas):
if set(core_list_for_rank) <= set(numa_cores[i]):
cmd.append("-m")
cmd.append(f"{i}")
break
cmd.append("-C")
core_list_str = f"{core_list_for_rank[0]}"
for core_id in core_list_for_rank[1:]:
core_list_str = f"{core_list_str},{core_id}"
cmd.append(f"{core_list_str}")
        if not args.no_python:
            cmd.append(sys.executable)
cmd.append("-u")
            if args.module:
                cmd.append("-m")
        else:
...@@ -242,13 +357,9 @@ def main():
        cmd += args.training_script_args
        if args.enable_each_rank_log != "None":
            log_file = os.path.join(args.enable_each_rank_log, f"{log_name_prefix}_rank{dist_rank}.log")
            log_fd = open(log_file, 'w')
            process = subprocess.Popen(cmd, env=current_env, stdout=log_fd, stderr=log_fd)
        else:
            process = subprocess.Popen(cmd, env=current_env)
...@@ -264,7 +375,7 @@ def main():
        args.min_elastic_nodes = 1
    if args.max_elastic_nodes == -1:
        args.max_elastic_nodes = args.nnodes
    assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"
    current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
...@@ -287,8 +398,7 @@ def main():
    # Creating config for rendezvous class
    rdzv_parameters = RendezvousParameters(backend='c10d',
                                           endpoint=args.master_addr + ":" + str(args.master_port),
                                           run_id=run_id,
                                           min_nodes=args.min_elastic_nodes,
                                           max_nodes=args.max_elastic_nodes,
...