Commit 67ea635f authored by aiss

push dsv0.8.2 version

parent 1b2721ad
'''Copyright The Microsoft DeepSpeed Team'''
#########################################
# Compression Methods
# Each compression method below has several sub-components
#########################################
COMPRESSION_TRAINING = "compression_training"
SHARED_PARAMETERS = "shared_parameters"
DIFFERENT_GROUPS = "different_groups"
TECHNIQUE_ENABLED = "enabled"
TECHNIQUE_SCHEDULE_OFFSET = "schedule_offset"
DIFFERENT_GROUPS_PARAMETERS = "params"
DIFFERENT_GROUPS_MODULE_SCOPE = "modules"
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT = "*"
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE = "related_modules"
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT = None
# COMPRESSION_TRAINING_ENABLED = "enabled"
# COMPRESSION_TRAINING_ENABLED_DEFAULT = False
####
# Layer Reduction
####
LAYER_REDUCTION = "layer_reduction"
LAYER_REDUCTION_ENABLED = "enabled"
LAYER_REDUCTION_ENABLED_DEFAULT = False
KEEP_NUMBER_LAYER = "keep_number_layer"
MODULE_NAME_PREFIX = "module_name_prefix"
TEACHER_LAYER = "teacher_layer"
OTHER_MODULE_NAME = "other_module_name"
####
# Weight Quantization
####
WEIGHT_QUANTIZATION = "weight_quantization"
WEIGHT_QUANTIZATION_PERIOD = "quantization_period"
WEIGHT_QUANTIZATION_PERIOD_DEFAULT = 1
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED = "quantize_weight_in_forward"
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT = False
WEIGHT_QUANTIZE_ENABLED = TECHNIQUE_ENABLED
WEIGHT_QUANTIZE_ENABLED_DEFAULT = False
WEIGHT_QUANTIZE_KERNEL = "quantizer_kernel"
WEIGHT_QUANTIZE_KERNEL_DEFAULT = False
WEIGHT_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 0
WEIGHT_QUANTIZE_GROUPS = "quantize_groups"
WEIGHT_QUANTIZE_GROUPS_DEFAULT = 1
WEIGHT_QUANTIZE_VERBOSE = "quantize_verbose"
WEIGHT_QUANTIZE_VERBOSE_DEFAULT = False
WEIGHT_QUANTIZE_TYPE = "quantization_type"
WEIGHT_QUANTIZE_TYPE_DEFAULT = "symmetric"
WEIGHT_QUANTIZE_SYMMETRIC = "symmetric"
WEIGHT_QUANTIZE_ASYMMETRIC = "asymmetric"
WEIGHT_QUANTIZE_ROUNDING = "rounding"
WEIGHT_QUANTIZE_ROUNDING_DEFAULT = "nearest"
WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING = "stochastic"
WEIGHT_QUANTIZE_NEAREST_ROUNDING = "nearest"
# may be removed in a cleaner version
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE = "fp16_mixed_quantize"
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED = "enabled"
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT = False
WEIGHT_QUANTIZE_CHANGE_RATIO = "quantize_change_ratio"
WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT = 0.001
WEIGHT_QUANTIZE_START_BITS = "start_bits"
WEIGHT_QUANTIZE_TARGET_BITS = "target_bits"
###
# Activation Quantization
###
ACTIVATION_QUANTIZATION = "activation_quantization"
ACTIVATION_QUANTIZATION_ENABLED = TECHNIQUE_ENABLED
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT = False
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT = 1000
ACTIVATION_QUANTIZE_TYPE = "quantization_type"
ACTIVATION_QUANTIZE_TYPE_DEFAULT = "symmetric"
ACTIVATION_QUANTIZE_SYMMETRIC = "symmetric"
ACTIVATION_QUANTIZE_ASYMMETRIC = "asymmetric"
ACTIVATION_QUANTIZE_RANGE = 'range_calibration'
ACTIVATION_QUANTIZE_RANGE_DEFAULT = 'dynamic'
ACTIVATION_QUANTIZE_RANGE_STATIC = 'static'
ACTIVATION_QUANTIZE_RANGE_DYNAMIC = 'dynamic'
ACTIVATION_QUANTIZE_BITS = "bits"
###
# Sparse Pruning
###
SPARSE_PRUNING = "sparse_pruning"
SPARSE_PRUNING_ENABLED = TECHNIQUE_ENABLED
SPARSE_PRUNING_ENABLED_DEFAULT = False
SPARSE_PRUNING_METHOD = "method"
SPARSE_PRUNING_METHOD_DEFAULT = "l1"
SPARSE_PRUNING_METHOD_L1 = "l1"
SPARSE_PRUNING_METHOD_TOPK = "topk"
SPARSE_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
SPARSE_PRUNING_DENSE_RATIO = "dense_ratio"
###
# Row Pruning
###
ROW_PRUNING = "row_pruning"
ROW_PRUNING_ENABLED = TECHNIQUE_ENABLED
ROW_PRUNING_ENABLED_DEFAULT = False
ROW_PRUNING_METHOD = "method"
ROW_PRUNING_METHOD_DEFAULT = "l1"
ROW_PRUNING_METHOD_L1 = "l1"
ROW_PRUNING_METHOD_TOPK = "topk"
ROW_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
ROW_PRUNING_DENSE_RATIO = "dense_ratio"
###
# Head Pruning
###
HEAD_PRUNING = "head_pruning"
HEAD_PRUNING_ENABLED = TECHNIQUE_ENABLED
HEAD_PRUNING_ENABLED_DEFAULT = False
HEAD_PRUNING_METHOD = "method"
HEAD_PRUNING_METHOD_DEFAULT = "topk"
HEAD_PRUNING_METHOD_L1 = "l1"
HEAD_PRUNING_METHOD_TOPK = "topk"
HEAD_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
HEAD_PRUNING_NUM_HEADS = "num_heads"
HEAD_PRUNING_DENSE_RATIO = "dense_ratio"
###
# Channel Pruning
###
CHANNEL_PRUNING = "channel_pruning"
CHANNEL_PRUNING_ENABLED = TECHNIQUE_ENABLED
CHANNEL_PRUNING_ENABLED_DEFAULT = False
CHANNEL_PRUNING_METHOD = "method"
CHANNEL_PRUNING_METHOD_DEFAULT = "l1"
CHANNEL_PRUNING_METHOD_L1 = "l1"
CHANNEL_PRUNING_METHOD_TOPK = "topk"
CHANNEL_PRUNING_SCHEDULE_OFFSET = TECHNIQUE_SCHEDULE_OFFSET
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT = 1000
CHANNEL_PRUNING_DENSE_RATIO = "dense_ratio"
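# Example (illustrative sketch; key names come from the constants above, the
# values are hypothetical): the constants compose into a JSON config such as
#
#   "compression_training": {
#     "weight_quantization": {
#       "shared_parameters": {"enabled": true, "schedule_offset": 0,
#                             "quantization_type": "symmetric"},
#       "different_groups": {
#         "wq_group": {"params": {"start_bits": 8, "target_bits": 8},
#                      "modules": ["*"]}
#       }
#     }
#   }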
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from .basic_layer import Embedding_Compress, LinearLayer_Compress, Conv2dLayer_Compress, BNLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress
from .constants import *
def recursive_getattr(model, module_name):
"""
Recursively get the attribute of a module.
Args:
model (`torch.nn.Module`)
The model to get the attribute from.
module_name (`str`)
The name of the module to get the attribute from.
"""
split_list = module_name.split('.')
output = model
for name in split_list:
output = getattr(output, name)
return output
def recursive_setattr(model, module_name, module):
"""
Recursively set the attribute of a module.
Args:
model (`torch.nn.Module`)
The model to set the attribute in.
module_name (`str`)
The name of the module to set the attribute in.
module (`torch.nn.Module`)
The module to set the attribute to.
"""
split_list = module_name.split('.')
output = model
for name in split_list[:-1]:
output = getattr(output, name)
output.__setattr__(split_list[-1], module)
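# Example (sketch, hypothetical module path and sizes): the two helpers mirror
# each other, so a nested submodule can be fetched and swapped by dotted name:
#   fc = recursive_getattr(model, 'encoder.layer.0.fc')
#   recursive_setattr(model, 'encoder.layer.0.fc',
#                     torch.nn.Linear(fc.in_features, fc.out_features))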
def module_replacement(model, module_name, compression_technique=None, mpu=None):
"""
Replace a module with a new module.
Args:
model (`torch.nn.Module`)
The model to replace the module in.
module_name (`str`)
The name of the module to replace.
compression_technique (`dict`)
The compression technique configuration to apply to the new module.
"""
# Get the old module
old_module = recursive_getattr(model, module_name)
need_bias = False
if hasattr(old_module, 'bias') and old_module.bias is not None:
need_bias = True
# Initialize the new module
if isinstance(old_module,
LinearLayer_Compress) or isinstance(old_module,
torch.nn.Linear):
if isinstance(old_module, LinearLayer_Compress):
new_module = old_module
else:
new_module = LinearLayer_Compress(old_module.in_features,
old_module.out_features,
bias=need_bias).to(
device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module,
Conv2dLayer_Compress) or isinstance(old_module,
torch.nn.Conv2d):
if isinstance(old_module, Conv2dLayer_Compress):
new_module = old_module
else:
new_module = Conv2dLayer_Compress(old_module.in_channels, old_module.out_channels, old_module.kernel_size, old_module.stride, old_module.padding, \
old_module.dilation, old_module.groups, need_bias, \
old_module.padding_mode).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module, torch.nn.BatchNorm2d):
new_module = BNLayer_Compress(old_module.num_features,
old_module.eps,
old_module.momentum,
old_module.affine,
old_module.track_running_stats).to(
old_module.weight.device,
old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
new_module.running_mean.data = old_module.running_mean.data
new_module.running_var.data = old_module.running_var.data
elif isinstance(old_module,
Embedding_Compress) or isinstance(old_module,
torch.nn.Embedding):
if isinstance(old_module, Embedding_Compress):
new_module = old_module
else:
new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
elif mpu is not None and (isinstance(old_module,
ColumnParallelLinear_Compress)
or isinstance(old_module,
mpu.ColumnParallelLinear)):
if isinstance(old_module, ColumnParallelLinear_Compress):
new_module = old_module
else:
new_module = ColumnParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
gather_output=old_module.gather_output,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif mpu is not None and (isinstance(old_module,
RowParallelLinear_Compress)
or isinstance(old_module,
mpu.RowParallelLinear)):
if isinstance(old_module, RowParallelLinear_Compress):
new_module = old_module
else:
new_module = RowParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
input_is_parallel=old_module.input_is_parallel,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
else:
new_module = None
if compression_technique is not None:
for k, v in compression_technique.items():
if k == SPARSE_PRUNING:
if v[SPARSE_PRUNING_ENABLED]:
new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
v[SPARSE_PRUNING_METHOD])
elif k == ROW_PRUNING:
if v[ROW_PRUNING_ENABLED]:
new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
v[ROW_PRUNING_METHOD])
elif k == HEAD_PRUNING:
if v[HEAD_PRUNING_ENABLED]:
new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
v[HEAD_PRUNING_METHOD],
v[HEAD_PRUNING_NUM_HEADS])
elif k == ACTIVATION_QUANTIZATION:
if v[ACTIVATION_QUANTIZATION_ENABLED]:
new_module.enable_activation_quantization(
v[ACTIVATION_QUANTIZE_BITS],
v[ACTIVATION_QUANTIZE_TYPE],
v[ACTIVATION_QUANTIZE_RANGE])
elif k == WEIGHT_QUANTIZATION:
if v[WEIGHT_QUANTIZE_ENABLED]:
new_module.enable_weight_quantization(
v[WEIGHT_QUANTIZE_START_BITS],
v[WEIGHT_QUANTIZE_TARGET_BITS],
v[WEIGHT_QUANTIZATION_PERIOD],
v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
v[WEIGHT_QUANTIZE_TYPE],
v[WEIGHT_QUANTIZE_GROUPS])
elif k == CHANNEL_PRUNING:
if v[CHANNEL_PRUNING_ENABLED]:
new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
v[CHANNEL_PRUNING_METHOD])
else:
raise NotImplementedError(
'Compression technique {} is not implemented'.format(k))
# Replace the old module with the new one
recursive_setattr(model, module_name, new_module)
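# Example (illustrative sketch): a `compression_technique` dict for
# `module_replacement`, keyed by the constants above; this mirrors what the
# JSON config parser would be expected to produce for one group (the module
# name and the values are hypothetical):
#   technique = {
#       SPARSE_PRUNING: {
#           SPARSE_PRUNING_ENABLED: True,
#           SPARSE_PRUNING_DENSE_RATIO: 0.3,
#           SPARSE_PRUNING_METHOD: SPARSE_PRUNING_METHOD_L1,
#       }
#   }
#   module_replacement(model, 'fc1', compression_technique=technique)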
def is_module_compressible(module, mpu=None):
ret = isinstance(module, torch.nn.Linear) or \
isinstance(module, torch.nn.Conv2d) or \
isinstance(module, torch.nn.Embedding) or \
isinstance(module, torch.nn.BatchNorm2d)
if mpu is not None:
ret = ret or isinstance(module,
mpu.RowParallelLinear) or isinstance(
module,
mpu.ColumnParallelLinear)
return ret
def compression_preparation(model, compression_technique_list, mpu):
"""
Prepare the compression techniques of a model.
Args:
model (`torch.nn.Module`)
The model to prepare the compression techniques of.
compression_technique_list (`list`)
The list of compression techniques to apply to the model.
"""
# First, replace every compressible module with its compression wrapper
for module_name, module in model.named_modules():
if is_module_compressible(module, mpu):
module_replacement(model, module_name, mpu=mpu)
for module_name_lists, _, compression_technique in compression_technique_list:
for mnl in module_name_lists:
for module_name in mnl:
module_replacement(model, module_name, compression_technique)
return model
def fix_compression(model,
module_name,
compression_technique,
mask=None,
dim_reduction=False):
"""
Finalize (make permanent) the compression technique of a module.
Args:
model (`torch.nn.Module`)
The model to fix the compression technique of.
module_name (`str`)
The name of the module to fix the compression technique of.
compression_technique (`dict`)
The compression technique configuration to finalize for the module.
"""
# Fetch the module and dispatch to the matching fix helper
module = recursive_getattr(model, module_name)
for k, v in compression_technique.items():
if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[
WEIGHT_QUANTIZE_ENABLED]:
return module.fix_weight_quantization()
elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
return module.fix_sparse_pruning_helper()
elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
return module.fix_head_pruning_helper(mask,
v[HEAD_PRUNING_NUM_HEADS],
dim_reduction=dim_reduction)
elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
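# Example (sketch): once training is done, make row pruning permanent on a
# hypothetical compressed module 'fc1' and shrink its output dimension:
#   fix_compression(model, 'fc1',
#                   {ROW_PRUNING: {ROW_PRUNING_ENABLED: True}},
#                   dim_reduction=True)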
def convert_conv1d_to_linear(model, convert_type):
'''
Helper function to convert Conv1D modules to Linear (e.g., to convert GPT2 from HF)
'''
if hasattr(model, 'module'):
c_model = model.module
else:
c_model = model
for name, module in c_model.named_modules():
if isinstance(module, convert_type):
old_module = recursive_getattr(c_model, name)
new_module = torch.nn.Linear(
old_module.weight.data.size(0),
old_module.weight.data.size(1),
bias=True if old_module.bias is not None else False)
new_module.weight.data = old_module.weight.data.t().contiguous()
if new_module.bias is not None:
new_module.bias.data = old_module.bias.data.view(-1)
recursive_setattr(c_model, name, new_module)
return model
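# Example (sketch): converting HF GPT-2's Conv1D projections so the compression
# wrappers above can later replace them. The import path below is the one used
# by recent `transformers` releases and is an assumption here:
#   from transformers.pytorch_utils import Conv1D
#   model = convert_conv1d_to_linear(model, Conv1D)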
'''Copyright The Microsoft DeepSpeed Team'''
from .compress import get_module_name
from .constants import *
from .helper import recursive_getattr
from deepspeed.utils import logger
class compression_scheduler():
'''
Used to schedule different compression methods
'''
def __init__(self, model, compression_config):
self.model = model
self.compression_config = compression_config
self.make_init()
self.training_steps = 0
self.weight_quantization_enabled = False
self.verbose = {
WEIGHT_QUANTIZATION: False,
ACTIVATION_QUANTIZATION: False,
SPARSE_PRUNING: False,
HEAD_PRUNING: False,
ROW_PRUNING: False,
CHANNEL_PRUNING: False
}
def make_init(self):
self.different_compression_methods = {}
for method, method_content in self.compression_config.items():
if LAYER_REDUCTION in method:
continue
self.different_compression_methods[method] = {
TECHNIQUE_ENABLED: False,
SHARED_PARAMETERS: None,
DIFFERENT_GROUPS: []
}
exist_module_name = set()
shared_parameters = method_content[SHARED_PARAMETERS]
self.different_compression_methods[method][
TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
self.different_compression_methods[method][
SHARED_PARAMETERS] = shared_parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
module_name_list = []
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, self.model, key_word, exist_module_name, verbose=False)
module_name_list.extend(module_name)
if module_name_list:
self.different_compression_methods[method][DIFFERENT_GROUPS].append([
group_name,
module_name_list,
method_parameters.copy().pop('params')
])
def check_weight_quantization(self):
# check weight quantization
wq = self.different_compression_methods[WEIGHT_QUANTIZATION]
if not wq[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = wq[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.weight_quantization_enabled = True
if not self.verbose[WEIGHT_QUANTIZATION]:
logger.info(
f'Weight quantization is enabled at step {self.training_steps}')
self.weight_quantization_enabled = True
self.verbose[WEIGHT_QUANTIZATION] = True
def check_activation_quantization(self):
# check activation quantization
aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
if not aq[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = aq[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.activation_quantization_enabled = True
if not self.verbose[ACTIVATION_QUANTIZATION]:
logger.info(
f'Activation quantization is enabled at step {self.training_steps}'
)
self.verbose[ACTIVATION_QUANTIZATION] = True
def check_sparse_pruning(self):
# check sparse pruning
sp = self.different_compression_methods[SPARSE_PRUNING]
if not sp[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = sp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.sparse_pruning_enabled = True
if not self.verbose[SPARSE_PRUNING]:
logger.info(
f'Sparse pruning is enabled at step {self.training_steps}')
self.verbose[SPARSE_PRUNING] = True
def check_head_pruning(self):
# check head pruning
hp = self.different_compression_methods[HEAD_PRUNING]
if not hp[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = hp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.head_pruning_enabled = True
if not self.verbose[HEAD_PRUNING]:
logger.info(f'Head pruning is enabled at step {self.training_steps}')
self.verbose[HEAD_PRUNING] = True
def check_row_pruning(self):
# check row pruning
rp = self.different_compression_methods[ROW_PRUNING]
if not rp[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = rp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.row_pruning_enabled = True
if not self.verbose[ROW_PRUNING]:
logger.info(f'Row pruning is enabled at step {self.training_steps}')
self.verbose[ROW_PRUNING] = True
def check_channel_pruning(self):
# check channel pruning
cp = self.different_compression_methods[CHANNEL_PRUNING]
if not cp[TECHNIQUE_ENABLED]:
return
else:
shared_parameters = cp[SHARED_PARAMETERS]
if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]:
for module_name in module_name_list:
module = recursive_getattr(self.model, module_name)
module.channel_pruning_enabled = True
if not self.verbose[CHANNEL_PRUNING]:
logger.info(
f'Channel pruning is enabled at step {self.training_steps}')
self.verbose[CHANNEL_PRUNING] = True
def check_all_modules(self):
# check all different compression methods we have
self.check_weight_quantization()
self.check_activation_quantization()
self.check_sparse_pruning()
self.check_head_pruning()
self.check_row_pruning()
self.check_channel_pruning()
def step(self, step_zero_check=False):
if not step_zero_check:
self.training_steps += 1
self.check_all_modules()
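# Example (illustrative sketch; `engine`, `dataloader` and `config` are
# hypothetical stand-ins for a DeepSpeed engine, a data loader and the parsed
# compression section of the config):
#   scheduler = compression_scheduler(model, config[COMPRESSION_TRAINING])
#   scheduler.step(step_zero_check=True)  # apply any offset-0 techniques up front
#   for batch in dataloader:
#       loss = engine(batch)
#       engine.backward(loss)
#       engine.step()
#       scheduler.step()  # bumps training_steps and re-checks every offset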
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from torch import autograd
import math
class TopKBinarizer(autograd.Function):
"""
Top-k Binarizer.
Computes a binary mask M from a real value matrix S such that `M_{i,j} = 1` if and only if `S_{i,j}`
is among the k% highest values of S.
Implementation is inspired from:
https://github.com/yaozhewei/MLPruning
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, threshold: float, sigmoid: bool):
"""
Args:
inputs (`torch.FloatTensor`)
The input matrix from which the binarizer computes the binary mask.
threshold (`float` or 0-dim `torch.Tensor`)
The fraction of weights to keep (the rest is pruned);
`threshold` is between 0 and 1. When `sigmoid` is True it is a
tensor and a sigmoid is applied to it first.
sigmoid (`bool`)
Whether to apply a sigmoid to the threshold
Returns:
mask (`torch.FloatTensor`)
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
retained, 0 - the associated weight is pruned).
"""
# Get the subnetwork by sorting the inputs and using the top threshold
if sigmoid:
threshold = torch.sigmoid(threshold).item()
ctx.sigmoid = sigmoid
mask = inputs.clone()
_, idx = inputs.flatten().sort(descending=True)
j = math.ceil(threshold * inputs.numel())
# flat_out and mask access the same memory.
flat_out = mask.flatten()
flat_out[idx[j:]] = 0.
flat_out[idx[:j]] = 1.
ctx.save_for_backward(mask)
return mask
@staticmethod
def backward(ctx, gradOutput):
mask, = ctx.saved_tensors
if ctx.sigmoid:
return gradOutput.clone(), ((gradOutput * mask).sum()).view(-1), None
else:
return gradOutput.clone(), None, None
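# Example (sketch): keep the top 30% of scores with a straight-through gradient.
#   scores = torch.randn(4, 4, requires_grad=True)
#   mask = TopKBinarizer.apply(scores, 0.3, False)  # {0., 1.} mask, same shape
#   mask.sum().backward()  # grads flow to `scores` unchanged (sigmoid=False)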
class SymQuantizer(torch.autograd.Function):
"""
Symmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
Args:
inputs (`torch.FloatTensor`)
The input which needs to be quantized
num_bits (int, >=4)
Number of bits to use for quantization
min_value/max_value (torch.FloatTensor)
Used for static activation quantization
num_groups (int)
How many groups to partition the quantization into
Returns:
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
input = input.reshape(num_groups, -1)
max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1)
else:
max_input = torch.max(min_value.abs(), max_value).view(-1)
scale = 2 * max_input / q_range
output = (input / scale).round().clamp(-q_range // 2, q_range // 2 - 1) * scale
output = output.reshape(input_shape).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input, None, None, None, None
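# Example (sketch): fake-quantize to 8 bits in 4 groups; the backward pass is a
# straight-through estimator, so gradients pass to `x` unchanged.
#   x = torch.randn(4, 16, requires_grad=True)  # numel must divide by num_groups
#   x_q = SymQuantizer.apply(x, 8, None, None, 4)
#   assert x_q.shape == x.shape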
class AsymQuantizer(torch.autograd.Function):
"""
Asymmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
Args:
inputs (`torch.FloatTensor`)
The input which needs to be quantized
num_bits (int, >=4)
Number of bits to use for quantization
min_value/max_value (torch.FloatTensor)
Used for static activation quantization
num_groups (int)
How many groups to partition the quantization into
Returns:
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
input = input.reshape(num_groups, -1)
min_value = input.amin(dim=-1, keepdim=True)
max_value = input.amax(dim=-1, keepdim=True)
scale = (max_value - min_value) / q_range
zero_point = (min_value / scale).round() * scale
output = (
(input - zero_point) / scale).round().clamp(0,
q_range - 1) * scale + zero_point
output = output.reshape(input_shape).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input, None, None, None, None
class TernaryQuantizer(torch.autograd.Function):
"""
Ternary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
Args:
inputs (`torch.FloatTensor`)
The input which needs to be quantized
num_bits (int)
Dummy variable
min_value/max_value (torch.FloatTensor)
Used for static activation quantization; for now they are dummy variables
num_groups (int)
How many groups to partition the quantization into
Returns:
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None and max_value is None)
input_flat = input.reshape(num_groups, -1)
n = input_flat.shape[1]
m = input_flat.norm(p=1, dim=1).div(n)
thres = (0.7 * m).view(-1, 1)
pos = (input_flat > thres).type(input.type())
neg = (input_flat < -thres).type(input.type())
mask = (input_flat.abs() > thres).type(input.type())
alpha = ((mask * input_flat).abs().sum(dim=1) / mask.sum(dim=1)).view(-1, 1)
output = alpha * pos - alpha * neg
output = output.reshape(input.shape).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input, None, None, None, None
class BinaryQuantizer(torch.autograd.Function):
"""
Binary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
Args:
inputs (`torch.FloatTensor`)
The input which needs to be quantized
num_bits (int)
Dummy variable
min_value/max_value (torch.FloatTensor)
Used for static activation quantization; for now they are dummy variables
num_groups (int)
How many groups to partition the quantization into
Returns:
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None and max_value is None)
input_flat = input.reshape(num_groups, -1)
n = input_flat.shape[1]
m = input_flat.norm(p=1, dim=1, keepdim=True).div(n)
output = input_flat.sign().mul(m)
output = output.reshape(input.shape).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
return grad_input, None, None, None, None
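# Example (sketch): BinaryQuantizer ignores num_bits and scales sign(w) by the
# per-group mean absolute value:
#   w = torch.randn(2, 8)
#   w_b = BinaryQuantizer.apply(w, 1, None, None, 2)  # sign(w)*mean(|w|) per group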
@@ -14,6 +14,5 @@ TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
'''Copyright The Microsoft DeepSpeed Team'''
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
from .utils import is_torch_elastic_compatible
from .constants import ENABLED, ENABLED_DEFAULT, ELASTICITY
if is_torch_elastic_compatible():
from .elastic_agent import DSElasticAgent
@@ -79,6 +79,7 @@ class ElasticityConfig:
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
@@ -88,6 +89,20 @@ class ElasticityConfig:
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1:
raise ElasticityConfigError(
"Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1:
raise ElasticityConfigError(
"Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
......
@@ -27,7 +27,7 @@ Elasticity should be enabled as:
ELASTICITY = 'elasticity'
# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1
LATEST_ELASTICITY_VERSION = 0.2
ENABLED = 'enabled'
ENABLED_DEFAULT = False
@@ -46,6 +46,12 @@ MIN_GPUS_DEFAULT = 1
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000
NUM_GPUS_PER_NODE = 'num_gpus_per_node'
NUM_GPUS_PER_NODE_DEFAULT = 1
MODEL_PARLLEL_SIZE = "model_parallel_size"
MODEL_PARLLEL_SIZE_DEFAULT = 1
# Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = 0
......
'''Copyright The Microsoft DeepSpeed Team'''
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
RunResult,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed import Store
import time
import os
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
import shutil
import copy
from contextlib import closing
import subprocess
class DSElasticAgent(LocalElasticAgent):
def __init__(
self,
spec: WorkerSpec,
env: Dict,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_dir: Optional[str] = None,
):
super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
self.ds_env = env
@staticmethod
def _set_master_addr_port(store: Store,
master_addr: Optional[str],
master_port: Optional[int]):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
master_port = sock.getsockname()[1]
if master_addr is None:
# master_addr = _get_fq_hostname()
result = subprocess.check_output("hostname -I", shell=True)
master_addr = result.decode('utf-8').split()[0]
store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
master_addr, master_port = super()._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts
use_agent_store = spec.rdzv_handler.get_backend() == "static"
args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
for worker in worker_group.workers:
local_rank = worker.local_rank
worker_env_ds = copy.deepcopy(self.ds_env)
worker_env_elastic = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": master_addr,
"MASTER_PORT": str(master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
str(1)),
}
worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ:
worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
envs[local_rank] = worker_env_ds
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)
# scaling events do not count towards restarts (gets same attempt #)
# remove existing log dir if this restart is due to a scaling event
attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)
assert spec.entrypoint is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
log_dir=attempt_log_dir,
start_method=self._start_method,
redirects=spec.redirects,
tee=spec.tee,
)
return self._pcontext.pids()
def _invoke_run(self, role: str = "default") -> RunResult:
# NOTE: currently only works for a single role
spec = self._worker_group.spec
role = spec.role
log.info(
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler
participants = rdzv_handler._state_holder.state.participants
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state
expire_time = datetime.utcnow() - (
rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
_dead_nodes = [
node for node,
last_heartbeat in
rdzv_handler._state_holder.state.last_heartbeats.items()
if last_heartbeat < expire_time
]
put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier()
return run_result
elif state in {
WorkerState.UNHEALTHY,
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants
else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
log.info(f"[{role}] Detected {num_nodes_waiting} "
f"new nodes from group_rank={group_rank}; "
f"will restart worker group")
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants
else:
raise Exception(f"[{role}] Worker group in {state.name} state")
@@ -2,17 +2,15 @@
Copyright 2020 The Microsoft DeepSpeed Team
"""
import os
import re
import json
import numpy as np
import math
from packaging import version as pkg_version
from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
ElasticityIncompatibleWorldSize
from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG
MINIMUM_DEEPSPEED_VERSION, DEEPSPEED_ELASTICITY_CONFIG
from ..git_version_info import version as __version__
from ..utils import logger
@@ -93,7 +91,6 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
valid_gpus.append(i)
valid_gpus = set(valid_gpus)
valid_gpus = sorted(list(valid_gpus))
logger.info(f"Valid GPUs: {valid_gpus}")
return valid_gpus
@@ -173,6 +170,70 @@ def _get_compatible_gpus_v01(micro_batches,
return final_batch_size, valid_gpus
def _get_compatible_gpus_v02(micro_batches,
max_acceptable_batch_size,
current_num_gpus,
min_gpus=None,
max_gpus=None,
prefer_larger=True,
num_gpus_per_node=1,
model_parallel_size=1):
'''
Returns:
final_batch_size
valid_gpus
micro-batch size
'''
if num_gpus_per_node % model_parallel_size != 0:
raise ElasticityError(
f"In Elasticity v0.2, number of GPUs per node:" \
f"{num_gpus_per_node} should be divisible by " \
f"model parallel size {model_parallel_size}")
def get_microbatch(final_batch_size):
candidate_microbatch = None
for micro_batch in micro_batches:
if final_batch_size // current_num_gpus % micro_batch == 0:
if candidate_microbatch is None:
candidate_microbatch = micro_batch
if prefer_larger and candidate_microbatch < micro_batch:
candidate_microbatch = micro_batch
return candidate_microbatch
dp_size_per_node = num_gpus_per_node // model_parallel_size
final_batch_size, valid_world_size = _get_compatible_gpus_v01(micro_batches,
int(max_acceptable_batch_size/dp_size_per_node),
int(min_gpus/num_gpus_per_node),
int(max_gpus/num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size = int(final_batch_size) * dp_size_per_node
valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
if current_num_gpus // model_parallel_size in valid_dp_world_size:
candidate_microbatch = get_microbatch(final_batch_size)
return final_batch_size, valid_dp_world_size, candidate_microbatch
current_dp_size = (current_num_gpus / num_gpus_per_node) * dp_size_per_node
candidate_batch_sizes = []
for micro_batch in micro_batches:
min_batch_size = micro_batch * current_dp_size
factor = math.floor(max_acceptable_batch_size / float(min_batch_size))
candidate_batch_sizes.append(factor * min_batch_size)
used_microbatch = None
if prefer_larger:
candidate_batch_size = max(candidate_batch_sizes)
else:
candidate_batch_size = min(candidate_batch_sizes)
candidate_microbatch = get_microbatch(candidate_batch_size)
return candidate_batch_size, [int(current_dp_size)], candidate_microbatch
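# Worked example (sketch): with num_gpus_per_node=8 and model_parallel_size=2,
# dp_size_per_node is 4, so the v0.1 logic runs at node granularity with
# max_acceptable_batch_size/4 and min/max GPU counts divided by 8; its batch
# size and world sizes are then scaled back up by 4. The returned world sizes
# therefore count data-parallel ranks (GPUs / model_parallel_size), not raw GPUs.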
def _compatible_ds_version_check(target_deepspeed_version: str):
min_version = pkg_version.parse(MINIMUM_DEEPSPEED_VERSION)
target_version = pkg_version.parse(target_deepspeed_version)
@@ -223,7 +284,10 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
"guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0):
def compute_elastic_config(ds_config: dict,
target_deepspeed_version: str,
world_size=0,
return_microbatch=False):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below)
DeepSpeed will compute a total train batch size corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale
@@ -250,8 +314,9 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
target_deepspeed_version (str): When called from scheduling
infrastructure we want to ensure that the target deepspeed version is
compatible with the elasticity version used in the backend.
world_size (int, optional): Intended/current world size, will do some sanity
world_size (int, optional): Intended/current DP world size, will do some sanity
checks to ensure world size is actually valid with the config.
return_microbatch (bool, optional): whether to return micro batch size or not.
Raises:
ElasticityConfigError: Missing required elasticity config or elasticity disabled
@@ -277,6 +342,13 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
"('enabled':true) if running an elastic training job.")
elastic_config = ElasticityConfig(elastic_config_dict)
model_parallel_size = elastic_config.model_parallel_size
num_gpus_per_node = elastic_config.num_gpus_per_node
if model_parallel_size > 1 and float(elastic_config.version) != 0.2:
raise ElasticityConfigError(f"Elasticity V{elastic_config.version} " \
f"does not support model-parallel training. Given model-parallel size: " \
f"{model_parallel_size}")
if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
raise ElasticityConfigError("Attempting to run elasticity version " \
@@ -297,10 +369,39 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
prefer_larger=elastic_config.prefer_larger_batch_size)
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
elif float(elastic_config.version) == 0.2:
if world_size != 0:
current_num_gpus = world_size
else:
if "WORLD_SIZE" in os.environ and \
os.getenv('WORLD_SIZE').isnumeric():
current_num_gpus = int(os.getenv('WORLD_SIZE'))
else:
WORLD_SIZE = os.getenv('WORLD_SIZE')
raise ElasticityConfigError(
'Elasticity v0.2 needs WORLD_SIZE '\
'to compute valid batch size. '\
'Either give it as argument to function compute_elastic_config '\
'or set it as an environment variable. '\
f'Value of WORLD_SIZE as environment variable is {WORLD_SIZE}')
final_batch_size, valid_gpus, candidate_microbatch_size = _get_compatible_gpus_v02(
micro_batches=elastic_config.micro_batches,
max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
current_num_gpus=current_num_gpus,
min_gpus=elastic_config.min_gpus,
max_gpus=elastic_config.max_gpus,
prefer_larger=elastic_config.prefer_larger_batch_size,
num_gpus_per_node=num_gpus_per_node,
model_parallel_size=model_parallel_size)
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
else:
raise NotImplementedError(
f"Unable to find elastic logic for version: {elastic_config.version}")
logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")
if world_size > 0:
if world_size not in valid_gpus:
raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
@@ -317,4 +418,19 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
f" micro_batches={elastic_config.micro_batches}."
return final_batch_size, valid_gpus, micro_batch_size
if return_microbatch:
# Pick a valid micro batch size
if float(elastic_config.version) == 0.2:
return final_batch_size, valid_gpus, candidate_microbatch_size
else:
micro_batch_size = None
for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
if final_batch_size // world_size % mbsz == 0:
micro_batch_size = mbsz
break
assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
f" micro_batches={elastic_config.micro_batches}."
return final_batch_size, valid_gpus, micro_batch_size
return final_batch_size, valid_gpus
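# Example (illustrative sketch): calling the elasticity API from a client
# script. The batch/micro-batch key names below are assumptions (they are not
# defined in the constants shown above) and the values are hypothetical:
#   ds_config = {"elasticity": {"enabled": True, "max_train_batch_size": 64,
#                               "micro_batch_sizes": [2, 4], "version": 0.2,
#                               "min_gpus": 1, "max_gpus": 64,
#                               "num_gpus_per_node": 8, "model_parallel_size": 1}}
#   batch, valid, mbsz = compute_elastic_config(ds_config, "0.8.2",
#                                               world_size=8,
#                                               return_microbatch=True)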
'''Copyright The Microsoft DeepSpeed Team'''
import torch
def is_torch_elastic_compatible():
'''
Helper to look up the torch version. Elastic training support
was introduced in torch 1.11.x
'''
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR >= 11:
return True
else:
return False
'''Copyright The Microsoft DeepSpeed Team'''
import torch
import deepspeed
import subprocess
import argparse
from .ops.op_builder import ALL_OPS
from .ops.op_builder.all_ops import ALL_OPS
from .git_version_info import installed_ops, torch_info
from .ops import __compatible_ops__ as compatible_ops
from deepspeed.accelerator import get_accelerator
GREEN = '\033[92m'
RED = '\033[91m'
@@ -54,7 +56,7 @@ def op_report(verbose=True):
def ninja_installed():
try:
import ninja
import ninja # noqa: F401
except ImportError:
return False
return True
@@ -80,31 +82,33 @@ def nvcc_version():
def debug_report():
max_dots = 33
hip_version = None
if hasattr(torch.version, 'hip'):
hip_version = torch.version.hip
report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}")),
)
]
if get_accelerator().device_name() == 'cuda':
hip_version = getattr(torch.version, "hip", None)
report.extend([("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}"))])
else:
report.extend([("deepspeed wheel compiled w.",
f"torch {torch_info['version']} ")])
print("DeepSpeed general environment info:")
for name, value in report:
print(name, "." * (max_dots - len(name)), value)
......
'''Copyright The Microsoft DeepSpeed Team'''
try:
# This is populated by setup.py
from .git_version_info_installed import *
from .git_version_info_installed import * # noqa: F401
except ModuleNotFoundError:
import os
if os.path.isfile('version.txt'):
@@ -11,7 +13,7 @@ except ModuleNotFoundError:
git_hash = '[none]'
git_branch = '[none]'
from .ops.op_builder import ALL_OPS
from .ops.op_builder.all_ops import ALL_OPS
installed_ops = dict.fromkeys(ALL_OPS.keys(), False)
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False)
torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"}
'''Copyright The Microsoft DeepSpeed Team'''
from .engine import InferenceEngine
'''Copyright The Microsoft DeepSpeed Team'''
import torch
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
from pydantic import validator
from typing import Dict, Union
from enum import Enum
class DtypeEnum(Enum):
# The torch dtype must always be the first value (so we return torch.dtype)
fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16"
fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
int8 = torch.int8, "torch.int8", "int8"
# Copied from https://stackoverflow.com/a/43210118
# Allows us to use multiple values for each Enum index and returns first
# listed value when Enum is called
def __new__(cls, *values):
obj = object.__new__(cls)
# first value is canonical value
obj._value_ = values[0]
for other_value in values[1:]:
cls._value2member_map_[other_value] = obj
obj._all_values = values
return obj
def __repr__(self):
return "<%s.%s: %s>" % (
self.__class__.__name__,
self._name_,
", ".join([repr(v) for v in self._all_values]),
)
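# Example (sketch): any alias resolves to the same member, whose canonical
# value is the torch dtype:
#   assert DtypeEnum("half") is DtypeEnum(torch.float16)
#   assert DtypeEnum.fp16.value is torch.float16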
class MoETypeEnum(str, Enum):
residual = "residual"
standard = "standard"
class DeepSpeedTPConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """
enabled: bool = True
""" Turn tensor parallelism on/off. """
tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """
mpu: object = None
"""
A model parallelism unit object that implements
``get_{model,data}_parallel_{rank,group,world_size}()``.
"""
tp_group: object = None
class DeepSpeedMoEConfig(DeepSpeedConfigModel):
""" Sets parameters for MoE """
enabled: bool = True
ep_size: int = 1
"""
The expert-parallelism size which is used for partitioning the experts
across the GPUs in the expert-parallel group.
"""
moe_experts: list = Field([1], alias="num_experts")
""" The global number of experts used in an MoE layer. """
type: MoETypeEnum = MoETypeEnum.standard
"""
Specify the type of MoE layer. We have two types of MoE layer: 'Standard'
and 'Residual'.
"""
ep_mp_group: object = None
ep_group: object = Field(None, alias="expert_group")
class QuantTypeEnum(str, Enum):
asym = "asymmetric"
sym = "symmetric"
class BaseQuantConfig(DeepSpeedConfigModel):
enabled = True
num_bits = 8
q_type: QuantTypeEnum = QuantTypeEnum.sym
q_groups: int = 1
class WeightQuantConfig(BaseQuantConfig):
enabled = True
class ActivationQuantConfig(BaseQuantConfig):
enabled = True
class QKVQuantConfig(DeepSpeedConfigModel):
enabled = True
class QuantizationConfig(DeepSpeedConfigModel):
enabled: bool = True
activation: ActivationQuantConfig = ActivationQuantConfig()
weight: WeightQuantConfig = WeightQuantConfig()
qkv: QKVQuantConfig = QKVQuantConfig()
# todo: brainstorm on how to do ckpt loading for DS inference
class InferenceCheckpointConfig(DeepSpeedConfigModel):
checkpoint_dir: str = None
save_mp_checkpoint_path: str = None
base_dir: str = None
class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
""" Sets parameters for DeepSpeed Inference Engine. """
replace_with_kernel_inject: bool = Field(False, alias="kernel_inject")
"""
Set to true to inject inference kernels for models such as Bert, GPT2,
GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two
linear layers as a tuple:
`(attention_output projection, transformer output projection)`
"""
dtype: DtypeEnum = torch.float16
"""
Desired model data type, will convert model to this type.
Supported target types: `torch.half`, `torch.int8`, `torch.float`
"""
tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
"""
Configuration for tensor parallelism used to split the model across several
GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
"""
enable_cuda_graph: bool = False
"""
Use this flag to capture the CUDA graph of the inference ops, so that they
can run faster using the graph-replay method.
"""
zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
containing values for :any:`DeepSpeedZeroConfig`.
"""
triangular_masking: bool = Field(True, alias="tm")
"""
Controls the type of masking for attention scores in the transformer layer.
Note that the masking is application specific.
"""
moe: Union[bool, DeepSpeedMoEConfig] = {}
"""
Specify whether the Transformer is an MoE model. Expects a dictionary containing
values for :any:`DeepSpeedMoEConfig`.
"""
quant: QuantizationConfig = {}
"""
NOTE: only works for int8 dtype.
Quantization settings used for quantizing your model using the MoQ. The
setting can be a single element or a tuple. A single value is taken as the
number of quantization groups. A tuple enables extra grouping for the MLP
part of a Transformer layer (e.g. (True, 8) quantizes the model with 8
groups everywhere except the MLP, which uses 8x extra grouping). Expects a
dictionary containing values for
:any:`QuantizationConfig`.
"""
#todo: refactor the following 3 into the new checkpoint_config
checkpoint: str = None
"""
Path to deepspeed compatible checkpoint or path to JSON with load policy.
"""
base_dir: str = None
"""
This shows the root directory under which all the checkpoint files exist.
This can be passed through the json config too.
"""
save_mp_checkpoint_path: str = None
"""
The path at which to save the loaded model as a checkpoint. This
feature is used for adjusting the parallelism degree to help alleviate
model loading overhead. No new checkpoint is saved if no path is
passed.
"""
checkpoint_config: InferenceCheckpointConfig = Field({}, alias="ckpt_config")
"""
TODO: Add docs. Expects a dictionary containing values for
:any:`InferenceCheckpointConfig`.
"""
return_tuple: bool = True
"""
Specify whether the transformer layers should return a tuple or a
Tensor.
"""
training_mp_size: int = 1
"""
If loading a checkpoint, this is the mp size it was trained with; it
may differ from the mp size you want to use during
inference.
"""
replace_method: str = Field(
"auto",
deprecated=True,
deprecated_msg=
"This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
)
injection_policy: Dict = Field(None, alias="injection_dict")
"""
Dictionary mapping a client nn.Module to its corresponding injection
policy. e.g., `{BertLayer : deepspeed.inference.HFBertLayerPolicy}`
"""
injection_policy_tuple: tuple = None
""" TODO: Add docs """
config: Dict = Field(
None,
alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens")
"""
The maximum number of tokens the inference engine can work
with, including the input and output tokens. Consider increasing it
to the token length required for your use case.
"""
mp_size: int = Field(1, deprecated=True, new_param="tensor_parallel.tp_size")
"""
Desired model parallel size, default is 1 meaning no model parallelism.
Deprecated, please use the ``tensor_parallel`` config to control model
parallelism.
"""
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
ep_group: object = Field(None,
alias="expert_group",
deprecated=True,
new_param="moe.ep_group")
ep_mp_group: object = Field(None,
alias="expert_mp_group",
deprecated=True,
new_param="moe.ep_mp_group")
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
deprecated=True,
new_param="moe.type")
@validator("moe")
def moe_backward_compat(cls, field_value, values):
if isinstance(field_value, bool):
return DeepSpeedMoEConfig(moe=field_value)
return field_value
class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
@@ -2,22 +2,34 @@
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
import time
import os
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist
from torch.nn.modules import Module
import torch.distributed as dist
from packaging import version as pkg_version
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
from ..utils import logger, init_distributed
from ..module_inject import replace_transformer_layer, generic_injection
from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
from ..moe.layer import MoE
from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
from deepspeed.accelerator import get_accelerator
from ..module_inject.policy import TransformerPolicy
from ..module_inject.auto_tp import AutoTP
import torch.distributed as dist
import deepspeed.utils.groups as groups
from ..module_inject.replace_policy import generic_policies
DS_INFERENCE_ENABLED = False
from torch import nn
INFERENCE_MODEL_TIMER = "model-forward-inference"
class InferenceEngine(Module):
@@ -25,42 +37,11 @@ class InferenceEngine(Module):
inference_ep_group = None
expert_mp_group = None
def __init__(self,
model,
triangular_masking=True,
mp_size=1,
training_mp_size=1,
ep_size=1,
mpu=None,
ep_group=None,
expert_mp_group=None,
checkpoint=None,
dtype=None,
injection_dict=None,
return_tuple=True,
replace_method='auto',
quantization_setting=None,
replace_with_kernel_inject=False,
moe=False,
moe_experts=1,
moe_type='standard',
config=None):
def __init__(self, model, config):
"""
Args:
model: torch.nn.Module
mp_size: model-parallel size
mpu: model-parallel unit (used for Megatron-type models)
checkpoint: the json-path, showing the address of model-checkpoints
Example: {type: 'Megatron', 'checkpoints': [ckpt_mp0.pt, ckpt_mp1.pt], 'version': 1.0}
dtype: data-type by which inference is executed
injection_dict: the dictionary that shows the injection policy:
Example: {BertLayer: HFBertLayerPolicy}
return_tuple: if true, inference-API returns a tuple, otherwise a tensor
replace_method: the injection method, this can be passed as auto if no injection-policy is defined, in which case the injection is automatic based on the available policies
quantization_setting:
one of None, Tuple(mlp_extra_grouping, quantize_groups), quantize_groups
replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. Otherwise,
the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection)
config: DeepSpeedInferenceConfig
"""
global DS_INFERENCE_ENABLED
DS_INFERENCE_ENABLED = True
@@ -68,90 +49,165 @@ class InferenceEngine(Module):
super().__init__()
self.module = model
self._config = config
self._get_model_config_generate(config) # keep for weird backward compatibility
self._get_model_config_generate(config)
# patch model generate with ours if model uses it
if hasattr(self.module, "generate"):
self.generate = self._generate
self.mp_world_size = mp_size
self.checkpoint = checkpoint
self.dtype = dtype
self.injection_dict = injection_dict
self.mp_group = None
self.mpu = mpu
self._validate_args(mpu)
self.replace_method = replace_method
if hasattr(self.module, "config"):
TransformerPolicy.hf_model_config = self.module.config
# todo: keep this self.injection_dict because we don't want to change the config.injection_policy API
# todo: this will get changed when Molly's PR on auto injection dict is merged
self.injection_dict = config.injection_policy
# todo: refactor the mp_group and mp_size related in the next refactor
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu
#self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None
self.triangular_masking = triangular_masking
self.ep_size = ep_size
self.ep_group = ep_group
self.expert_mp_group = expert_mp_group
self._init_quantization_setting(quantization_setting)
# these are not needed in the config as we are creating them ourselves in the inference engine
self.ep_group = None # config.moe.ep_group
self.expert_mp_group = None # config.moe.ep_mp_group
self.cuda_graph_created = False
self.checkpoint_engine = TorchCheckpointEngine()
quantization_setting = None
self._init_quantization_setting(
quantization_setting
) # todo: update with the new quant config for weight quant
self.model_profile_enabled = False
self._model_times = []
# This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
self.remove_mask_prepare_for_bloom()
if self.checkpoint:
self._load_checkpoint(self.checkpoint)
if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
"If you want to use cuda graph, please upgrade torch to at least v1.10"
if config.checkpoint and not config.replace_with_kernel_inject:
self._load_checkpoint(config.checkpoint)
# convert model to intended dtype
if self.dtype:
self._convert_to_dtype()
if config.dtype:
self._convert_to_dtype(config)
if self.mpu:
self.mp_world_size = dist.get_world_size(
config.tensor_parallel.tp_size = dist.get_world_size(
group=self.mpu.get_model_parallel_group())
self.mp_group = mpu.get_model_parallel_group()
elif self.mp_world_size > 1:
self._create_model_parallel_group()
self.mp_group = self.mpu.get_model_parallel_group()
elif config.tensor_parallel.tp_size > 1:
self._create_model_parallel_group(config)
config.tensor_parallel.tp_group = self.mp_group
moe, _ = has_moe_layers(self.module)
if isinstance(self.module, torch.nn.Module):
moe, _ = has_moe_layers(self.module)
else:
moe = False
if moe and dist.get_world_size() > 1:
self._create_ep_parallel_group(moe_experts)
self._create_ep_parallel_group(config.moe.moe_experts)
# retain this from the old conditional argument being passed to apply_injection_policy()
if not config.replace_with_kernel_inject:
config.checkpoint = None
# We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
if self.injection_dict:
# 1. User specified Tensor Parallelism
assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
for client_module, injection_policy in self.injection_dict.items():
self._apply_injection_policy(client_module,
injection_policy,
return_tuple,
replace_with_kernel_inject,
moe,
moe_experts,
moe_type,
training_mp_size)
elif replace_method == 'auto':
self._apply_injection_policy(
return_tuple=return_tuple,
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size)
device = torch.cuda.current_device()
logger.info(f"Place model to device: {device}")
# construct the tuple and pass that instead of a string or dict.
if isinstance(injection_policy, str):
config.injection_policy_tuple = (injection_policy, )
else:
config.injection_policy_tuple = injection_policy
self._apply_injection_policy(config, client_module)
else:
if config.replace_with_kernel_inject:
# 2. DeepSpeed Kernel Injection
self._apply_injection_policy(config)
else:
# 3. Automatic Tensor Parallelism
parser_dict = AutoTP.tp_parser(model)
print("AutoTP: ", parser_dict)
for client_module, injection_policy in parser_dict:
if isinstance(injection_policy, str):
config.injection_policy_tuple = (injection_policy, )
else:
config.injection_policy_tuple = injection_policy
self._apply_injection_policy(config, client_module)
device = get_accelerator().current_device_name()
self.module.to(device)
if self.mp_world_size > 1:
self.model_orig_fwd = self.module.forward
self.module.forward = self.forward
else:
if config.tensor_parallel.tp_size > 1:
_rng_state = get_accelerator().get_rng_state().to(
get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu())
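# (assumption) Broadcasting rank 0's RNG state to all tensor-parallel ranks
# keeps any randomness during inference (e.g. sampling) identical across the
# group, so the sharded ranks stay in lockstep.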
if config.tensor_parallel.tp_size > 1:
assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
# Check if local CUDA graphs can be created in replacement modules
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
def profile_model_time(self, use_cuda_events=True):
if not self.model_profile_enabled and not self._config.enable_cuda_graph:
self.module.register_forward_pre_hook(self._pre_forward_hook)
self.module.register_forward_hook(self._post_forward_hook)
self.model_profile_enabled = True
self.use_cuda_events = use_cuda_events
if self.use_cuda_events:
self.timers = SynchronizedWallClockTimer()
# todo: remove this once all the config dicts are centralized from top level pydantic config
def _get_model_config_generate(self, config):
self.config = getattr(self.module, 'config', None) if config is None else config
self.generate = getattr(self.module, 'generate', None)
# this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
self.config = getattr(self.module,
'config',
None) if config.config is None else config.config
def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
if hasattr(self.module.transformer, '_prepare_attn_mask'):
self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
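# The identity lambda above makes HF BLOOM's _prepare_attn_mask a no-op,
# returning the attention mask unchanged; the injected DeepSpeed kernels
# (presumably) construct the causal mask themselves.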
def _pre_forward_hook(self, module, *inputs, **kwargs):
if self.use_cuda_events:
self.timers(INFERENCE_MODEL_TIMER).start()
else:
get_accelerator().synchronize()
self._start = time.time()
def _post_forward_hook(self, module, input, output):
if self.use_cuda_events:
self.timers(INFERENCE_MODEL_TIMER).stop()
elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True)
else:
get_accelerator().synchronize()
self._end = time.time()
elapsed_time = self._end - self._start
self._model_times.append(elapsed_time)
def _create_model_parallel_group(self):
def _create_model_parallel_group(self, config):
# Initialize the distributed backend and create the tensor-parallel group on first use
if InferenceEngine.inference_mp_group is None:
init_distributed()
local_rank = int(os.getenv('LOCAL_RANK', '0'))
torch.cuda.set_device(local_rank)
get_accelerator().set_device(local_rank)
ranks = [i for i in range(self.mp_world_size)]
ranks = [i for i in range(config.tensor_parallel.tp_size)]
self.mp_group = dist.new_group(ranks)
InferenceEngine.inference_mp_group = self.mp_group
else:
self.mp_group = InferenceEngine.inference_mp_group
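# Illustrative effect (assumption): with tp_size=4 this creates
# dist.new_group([0, 1, 2, 3]) once and caches it on the class, so every
# subsequently constructed InferenceEngine in this process reuses the same
# tensor-parallel group. Note the group always spans ranks 0..tp_size-1.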
@@ -194,66 +250,121 @@ class InferenceEngine(Module):
self.quantize_groups = quantization_setting
elif quantization_setting is not None:
self.quantize_groups = quantization_setting
logger.info(f"quantize_bits = {self.quantize_bits} "
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}")
def _validate_args(self, mpu):
if not isinstance(self.module, Module):
log_dist(
f"quantize_bits = {self.quantize_bits} "
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}",
[0])
# TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support the SD (Stable Diffusion) pipeline we need to skip this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self.mp_world_size, int) or self.mp_world_size < 1:
raise ValueError(f"mp_size must be an int >= 1, got {self.mp_world_size}")
if not isinstance(self._config.tensor_parallel.tp_size,
int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(
f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}"
)
if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self.checkpoint is not None and not isinstance(self.checkpoint, str):
if self._config.checkpoint is not None and not isinstance(
self._config.checkpoint,
(str,
dict)):
raise ValueError(
f"checkpoint must be None or a str, got {type(self.checkpoint)}")
f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}"
)
supported_dtypes = [None, torch.half, torch.int8, torch.float]
if self.dtype not in supported_dtypes:
if self._config.dtype not in supported_dtypes:
raise ValueError(
f"{self.dtype} not supported, valid dtype: {supported_dtypes}")
f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError(
f"injection_dict must be None or a dict, got: {self.injection_dict}")
def _apply_injection_policy(self,
client_module=None,
injection_policy=None,
return_tuple=True,
replace_with_kernel_inject=False,
moe=False,
moe_experts=1,
moe_type='standard',
training_mp_size=1):
replace_transformer_layer(client_module,
self.module,
triangular_masking=self.triangular_masking,
policy=injection_policy,
mp_size=self.mp_world_size,
mp_group=self.mp_group,
ep_group=self.ep_group,
expert_mp_group=self.expert_mp_group,
config=self.config,
fp16=(self.dtype == torch.half),
training=False,
return_tuple=return_tuple,
quantize=(self.dtype == torch.int8),
quantize_settings=(self.quantization_scales,
self.quantize_merge_count,
self.mlp_extra_grouping,
self.quantize_groups),
replace_with_kernel_inject=replace_with_kernel_inject,
moe=moe,
moe_experts=moe_experts,
moe_type=moe_type,
training_mp_size=training_mp_size)
def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group,
mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
error_msgs = []
def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'):
if 'query_key_value' in prefix:
module.weight = self.mp_replace.qkv_copy(
module.weight.data,
state_dict[prefix + 'weight'])
else:
module.weight = self.mp_replace.copy(module.weight.data,
state_dict[prefix + 'weight'])
else:
module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'):
module.norm.bias = self.mp_replace.copy(module.norm.bias,
state_dict[prefix + 'bias'])
else:
data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name())
module.bias = self.mp_replace.copy(module.bias, data)
layer_policies = {
nn.Linear: load,
nn.Embedding: load,
nn.LayerNorm: load,
LinearLayer: load,
LinearAllreduce: load
}
def load_module_recursive(module, prefix='', level=0):
for name, child in module.named_children():
if child.__class__ in layer_policies:
checking_key = prefix + name + '.'
if not any(checking_key in item for item in self.key_list):
continue
if len(list(child.parameters())) > 0 and list(
child.parameters())[0].numel() == 0:
if len(child.weight.ds_shape) == 1:
child = Normalize(dim=child.weight.ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
setattr(module, name, child)
load(child, self.sd, prefix + name + '.')
else:
load_module_recursive(child,
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(r_module)
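# Sketch of the loading scheme above: load_module_recursive walks
# named_children until it hits a leaf covered by layer_policies whose
# parameter keys appear in self.key_list, then `load` copies (and, via
# ReplaceWithTensorSlicing, shards) the tensors from self.sd. Leaves whose
# parameters are empty (numel() == 0, e.g. meta tensors with a 1-D ds_shape)
# are first rebuilt as Normalize modules before loading.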
def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
generic_injection(self.module,
fp16=(config.dtype == torch.half)
or (config.dtype == torch.int8),
enable_cuda_graph=config.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module):
# config is our DeepSpeedInferenceConfig and self.config is the HF model config
replace_transformer_layer(client_module,
self.module,
checkpoint,
config,
self.config)
def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -283,7 +394,7 @@ class InferenceEngine(Module):
if is_pipe_parallel:
raise RuntimeError(
'pipeline parallelism is currently not supported in inference.')
if os.path.isdir(load_dir):
if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
if tag is None:
latest_path = os.path.join(load_dir, "latest")
if os.path.isfile(latest_path):
@@ -291,38 +402,54 @@ class InferenceEngine(Module):
tag = fd.read().strip()
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, quantize_config = sd_loader.load(self.mp_world_size,
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
self.quantization_scales, self.quantize_merge_count = quantize_config
moe, _ = has_moe_layers(self.module)
if moe:
from deepspeed.runtime.engine import DeepSpeedEngine
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
DeepSpeedEngine.load_moe_state_dict(
load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
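# (assumption) get_sd_loader_json may return either a list of shard paths
# (an HF-style sharded checkpoint) or a single sd_loader object; the list
# case below streams the shards one at a time through
# load_model_with_checkpoint, re-sharding each tensor for tensor parallelism.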
if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
for i in range(1, len(sd_loader)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i],
map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
self.module.load_state_dict(
state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
self.quantization_scales, self.quantize_merge_count = quantize_config
moe, _ = has_moe_layers(self.module)
if moe:
from deepspeed.runtime.engine import DeepSpeedEngine
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
DeepSpeedEngine.load_moe_state_dict(
load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
checkpoint_engine=self.checkpoint_engine)
self.module.load_state_dict(
state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
def _choose_module_key(self, sd):
assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
@@ -332,25 +459,84 @@ class InferenceEngine(Module):
elif 'model' in sd:
return 'model'
def _convert_to_dtype(self):
if self.dtype is torch.int8 and self.quantization_scales is None:
def _convert_to_dtype(self, config):
if not isinstance(self.module, torch.nn.Module):
return
if False: #config.dtype is torch.int8 and self.quantization_scales is None:
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
model, self.quantization_scales = quantizer.model_quantize(self.module,
self.injection_dict,
self.quantize_bits,
self.quantize_groups)
elif self.dtype == torch.half:
elif config.dtype == torch.half:
self.module.half()
elif self.dtype == torch.float:
elif config.dtype == torch.bfloat16:
self.module.bfloat16()
elif config.dtype == torch.float:
self.module.float()
def _pre_forward_hook(self, module, *inputs, **kwargs):
for input in inputs:
if torch.is_tensor(input):
input = input.to(torch.cuda.current_device())
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = get_accelerator().Stream()
cuda_stream.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(cuda_stream):
for i in range(3):
ret = self.module(*inputs, **kwargs)
get_accelerator().current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs = torch.cuda.CUDAGraph()
self.static_inputs = inputs
self.static_kwargs = kwargs
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
self.cuda_graph_created = True
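# Note on the warm-up above: PyTorch requires that allocations such as cuBLAS
# workspaces happen outside graph capture, so a few eager iterations are run
# on a side stream before torch.cuda.graph() records the real forward pass.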
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
kwargs[k] = kwargs[k].to(torch.cuda.current_device())
self.static_kwargs[k].copy_(kwargs[k])
self._cuda_graphs.replay()
return self.static_output
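# Replay contract (minimal sketch of the pattern used above):
#
#   g = torch.cuda.CUDAGraph()
#   with torch.cuda.graph(g):
#       static_out = module(static_in)   # capture once
#   static_in.copy_(new_in)              # reuse the captured buffers
#   g.replay()                           # result appears in static_out
#
# The graph replays on fixed memory addresses, which is why _graph_replay
# copies fresh tensors into static_inputs/static_kwargs instead of passing
# them through directly.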
def model_times(self):
assert self.model_profile_enabled, "model profiling is not enabled"
model_times = self._model_times
if self._config.enable_cuda_graph and len(self._model_times) == 0:
raise ValueError(
"Model times are empty and cuda graph is enabled. If "
"this is a GPT-style model this combo is not supported. If this is a "
"BERT-style model this is a bug, please report it. "
f"Model type is: {type(self.module)}")
self._model_times = []
return model_times
def _module_match(self, module):
for policy in generic_policies:
policy = policy()
if policy.match_replaced(module):
return True
return False
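# (assumption) _local_cuda_graph_used below covers objects that are not
# torch.nn.Module, e.g. diffusers pipelines matched by generic_policies: if
# any replaced submodule manages its own CUDA graph (exposes
# enable_cuda_graph), the engine-level graph in forward() is skipped.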
def _local_cuda_graph_used(self, module):
if isinstance(module, torch.nn.Module):
return False
else:
sub_module_cuda_graph = False
for name in module.__dict__.keys():
sub_module = getattr(module, name)
if self._module_match(sub_module) and hasattr(sub_module,
"enable_cuda_graph"):
sub_module_cuda_graph = True
return sub_module_cuda_graph
def forward(self, *inputs, **kwargs):
"""Execute forward propagation
@@ -359,22 +545,44 @@
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
if self.mp_world_size > 1:
if self.mpu is None:
for input in inputs:
if torch.is_tensor(input):
input = input.to(torch.cuda.current_device())
if not input.is_contiguous():
input = input.contiguous()
dist.broadcast(input, 0)
for k in kwargs:
if torch.is_tensor(kwargs[k]):
kwargs[k] = kwargs[k].to(torch.cuda.current_device())
if not kwargs[k].is_contiguous():
kwargs[k] = kwargs[k].contiguous()
dist.broadcast(kwargs[k], 0)
outputs = self.model_orig_fwd(*inputs, **kwargs)
start = None
if self.model_profile_enabled and get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph:
get_accelerator().synchronize()
start = time.time()
if get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
else:
outputs = self.module(*inputs, **kwargs)
if self.model_profile_enabled and self._config.enable_cuda_graph:
get_accelerator().synchronize()
duration = time.time() - start
self._model_times.append(duration)
return outputs
def _generate(self, *inputs, **kwargs):
# Reset KV-cache at the beginning of generate
if hasattr(self.module, 'reset_cache'):
self.module.reset_cache()
num_beams = 1
if "generation_config" in kwargs:
gen_config = kwargs["generation_config"]
num_beams = getattr(gen_config, "num_beams", 1)
if "num_beams" in kwargs:
num_beams = kwargs["num_beams"]
if num_beams > 1:
raise NotImplementedError(
"DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
)
return self.module.generate(*inputs, **kwargs)
'''Copyright The Microsoft DeepSpeed Team'''
@@ -4,6 +4,9 @@ PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024
OPENMPI_LAUNCHER = 'openmpi'
MPICH_LAUNCHER = 'mpich'
SLURM_LAUNCHER = 'slurm'
MVAPICH_LAUNCHER = 'mvapich'
MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
ELASTIC_TRAINING_ID_DEFAULT = "123456789"
# Copyright 2020 The Microsoft DeepSpeed Team
"""
DeepSpeed launcher, this is similar to torch.distributed.launch but supports
DeepSpeed launcher, this is similar to torch's distributed.launch but supports
additional features such as arbitrary gpu exclusion.
deepspeed.launcher.launch is intended to be run on a single worker node and
@@ -15,11 +15,15 @@ import json
import base64
import time
import signal
import psutil
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger
from ..elasticity import is_torch_elastic_compatible
from .constants import ELASTIC_TRAINING_ID_DEFAULT
PID_FILE_BASEPATH = "/tmp"
@@ -64,6 +68,20 @@ def parse_args():
help="Skip prepending the training script with "
"'python' - just execute it directly.")
parser.add_argument("--enable_elastic_training",
action="store_true",
help="Enable elastic training support.")
parser.add_argument("--min_elastic_nodes",
type=int,
default=-1,
help="Min number of nodes in elastic training.")
parser.add_argument("--max_elastic_nodes",
type=int,
default=-1,
help="Max number of nodes in elastic training.")
parser.add_argument("--no_local_rank",
action="store_true",
help="Do not pass local_rank as an argument when calling "
@@ -74,6 +92,12 @@ def parse_args():
default=0,
help="main launching process pid, for internal pid tracking")
parser.add_argument(
"--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
# positional
parser.add_argument("training_script",
type=str,
@@ -87,6 +111,21 @@ def parse_args():
return parser.parse_args()
# Adapted from https://psutil.readthedocs.io/en/latest/#kill-process-tree
def terminate_process_tree(pid):
process = psutil.Process(pid)
children = process.children(recursive=True)
children.append(process)
for child in children:
try:
child.terminate()
except psutil.NoSuchProcess:
pass
gone, alive = psutil.wait_procs(children, timeout=30)
for p in alive:
p.kill()
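# Illustrative call (see the signal handler further below): invoking
# terminate_process_tree(process.pid) instead of process.kill() also reaps
# grandchildren that the training script may have spawned.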
def main():
args = parse_args()
current_env = os.environ.copy()
@@ -143,15 +182,93 @@ def main():
with open(pid_file, 'w') as fd:
fd.write(f"{launcher_pid}")
if not is_torch_elastic_compatible():
if args.enable_elastic_training:
logger.info(f"Disabling elastic training support as \
PyTorch version should be greater than 1.11.x")
args.enable_elastic_training = False
if os.path.exists(DLTS_POD_ENV_PATH):
with open(DLTS_POD_ENV_PATH) as file:
lines = file.readlines()
lines = [line.rstrip() for line in lines]
for line in lines:
if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
'export FC_TASK_INDEX'):
key_val = line.split()[1]
key, val = key_val.split('=')
current_env[key] = val
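# e.g. a line "export FC_TASK_INDEX=3" in DLTS_POD_ENV_PATH yields
# current_env["FC_TASK_INDEX"] = "3" via the parsing above.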
processes = []
cmd = []
for local_rank in range(0, num_local_procs):
# each process's rank
dist_rank = global_rank_mapping[local_node][local_rank]
current_env["RANK"] = str(dist_rank)
current_env["LOCAL_RANK"] = str(local_rank)
# spawn the processes
if not args.enable_elastic_training:
if args.enable_each_rank_log != "None":
# prepare the log path and the file name prefix
if os.path.isfile(args.enable_each_rank_log):
raise ValueError(
f"{args.enable_each_rank_log} should not be a file, it should be a directory."
)
if not os.path.exists(args.enable_each_rank_log):
try:
os.makedirs(args.enable_each_rank_log)
except Exception as e:
print(e)
raise ValueError(
f"unable to create directory {args.enable_each_rank_log} for each rank log."
)
log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())
for local_rank in range(0, num_local_procs):
# each process's rank
dist_rank = global_rank_mapping[local_node][local_rank]
current_env["RANK"] = str(dist_rank)
current_env["LOCAL_RANK"] = str(local_rank)
# spawn the processes
cmd = []
if not args.no_python:
cmd = [sys.executable, "-u"]
if args.module:
cmd.append("-m")
else:
if args.module:
raise ValueError("Don't use both the '--no_python' flag"
" and the '--module' flag at the same time.")
cmd.append(args.training_script)
# A user may not want to pass local_rank as a keyword arg so we make this optional.
if not args.no_local_rank:
cmd.append(f"--local_rank={local_rank}")
cmd += args.training_script_args
if args.enable_each_rank_log != "None":
log_file = os.path.join(args.enable_each_rank_log,
f"{log_name_prefix}_rank{dist_rank}.log")
log_fd = open(log_file, 'w')
process = subprocess.Popen(cmd,
env=current_env,
stdout=log_fd,
stderr=log_fd)
else:
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
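# Summary of the branch above: with --enable_each_rank_log=DIR each rank's
# stdout/stderr is redirected to DIR/<timestamp>_rank<global_rank>.log;
# otherwise all ranks inherit the launcher's console.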
else:
from ..elasticity import DSElasticAgent
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.agent.server.api import WorkerSpec
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.multiprocessing import Std
if args.min_elastic_nodes == -1:
args.min_elastic_nodes = 1
if args.max_elastic_nodes == -1:
args.max_elastic_nodes = args.nnodes
assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"
current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
# Get config and arguments
cmd = []
if not args.no_python:
cmd = [sys.executable, "-u"]
@@ -162,13 +279,36 @@ def main():
raise ValueError("Don't use both the '--no_python' flag"
" and the '--module' flag at the same time.")
cmd.append(args.training_script)
# A user may not want to pass local_rank as a keyword arg so we make this optional.
if not args.no_local_rank:
cmd.append(f"--local_rank={local_rank}")
cmd += args.training_script_args
process = subprocess.Popen(cmd, env=current_env)
processes.append(process)
cmd_args = cmd[1:]
rdzv_configs: Dict[str, str] = {'timeout': 100}
run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)
# Creating config for rendezvous class
rdzv_parameters = RendezvousParameters(backend='c10d',
endpoint=args.master_addr + ":" +
str(args.master_port),
run_id=run_id,
min_nodes=args.min_elastic_nodes,
max_nodes=args.max_elastic_nodes,
**rdzv_configs)
spec = WorkerSpec(
role='trainer',
local_world_size=num_local_procs,
entrypoint=cmd[0],
args=cmd[1:],
rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
max_restarts=100,
monitor_interval=5,
redirects=Std.from_str("0"),
tee=Std.from_str("0"),
master_addr=None,
master_port=None,
)
agent = DSElasticAgent(spec, current_env)
agent.run()
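# Sketch of the elastic path above (assumptions about the torch elastic API
# in use): every node runs a DSElasticAgent; agents rendezvous through the
# c10d store at master_addr:master_port under ELASTIC_RUN_ID, agree on a
# membership between min_elastic_nodes and max_elastic_nodes, then each
# spawns num_local_procs workers from `spec` and restarts them (up to
# max_restarts) whenever membership changes.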
sig_names = {2: "SIGINT", 15: "SIGTERM"}
last_return_code = None
@@ -177,7 +317,7 @@ def main():
for process in processes:
logger.info(f"Killing subprocess {process.pid}")
try:
process.kill()
terminate_process_tree(process.pid)
except Exception:
pass
if last_return_code is not None:
......
'''Copyright The Microsoft DeepSpeed Team'''
import os
import sys
import shutil
import subprocess
import warnings
from shlex import quote
from shlex import split
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator
from ..utils import logger
from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
@@ -66,7 +68,14 @@ class PDSHRunner(MultiNodeRunner):
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
pdsh_cmd_args = [
'pdsh',
'-S',
'-f',
str(PDSH_MAX_FAN_OUT),
'-w',
active_workers
] + split(self.args.launcher_args)
exports = ""
for key, val in self.exports.items():
@@ -94,8 +103,16 @@ class PDSHRunner(MultiNodeRunner):
deepspeed_launch.append("--no_local_rank")
if self.args.save_pid:
deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
if self.args.elastic_training:
deepspeed_launch.append("--enable_elastic_training")
deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
deepspeed_launch.append(f"--min_elastic_nodes={self.args.min_elastic_nodes}")
cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]
kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]
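# The two lines above build a best-effort remote kill: the first few tokens
# of the launch command become a "pkill -f" pattern, fanned out over the same
# pdsh host list, so runner shutdown can terminate stray launchers on every
# worker node.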
return pdsh_cmd_args + deepspeed_launch + [self.user_script
] + self.user_arguments
] + self.user_arguments, kill_command
class OpenMPIRunner(MultiNodeRunner):
@@ -137,7 +154,7 @@ class OpenMPIRunner(MultiNodeRunner):
'--mca',
'btl_tcp_if_include',
'eth0',
]
] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
@@ -153,6 +170,102 @@ class OpenMPIRunner(MultiNodeRunner):
] + self.user_arguments
class MPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
def backend_exists(self):
#TODO: if IB is available we should suggest mpich
return shutil.which('mpirun') #mpich_info
@property
def name(self):
return "mpich"
def validate_args(self):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values()
total_process_count = sum(devices_per_node)
process_per_node = list(devices_per_node)[0]
mpirun_cmd = [
'mpirun',
'-n',
f'{total_process_count}',
'-ppn',
f'{process_per_node}',
] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
export_cmd += ['-x', "{}={}".format(k, v)]
python_exec = []
if not self.args.no_python:
python_exec = [sys.executable, "-u"]
if self.args.module:
python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
class SlurmRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
def backend_exists(self):
return shutil.which('sinfo')
@property
def name(self):
return 'slurm'
def get_cmd(self, environment, active_resources):
assert not getattr(self.args, 'detect_nvlink_pairs', False), "slurm backend does not support remapping visible devices"
total_process_count = sum(self.resource_pool.values())
srun_cmd = [
'srun',
'-n',
f'{total_process_count}',
] + split(self.args.launcher_args)
if getattr(self.args, 'slurm_comment', ''):
srun_cmd += ['--comment', self.args.slurm_comment]
if self.args.include != "":
srun_cmd.append('--include')
srun_cmd.append(f'{self.args.include}')
if self.args.exclude != "":
srun_cmd.append('--exclude')
srun_cmd.append(f'{self.args.exclude}')
if self.args.num_nodes > 0:
srun_cmd.append('--nodes')
srun_cmd.append(f'{self.args.num_nodes}')
if self.args.num_gpus > 0:
srun_cmd.append('--gpus')
srun_cmd.append(f'{self.args.num_gpus}')
exports = '--export=ALL'
for key, val in self.exports.items():
exports += f",{key}={val}"
python_exec = [sys.executable, "-u"]
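# Hypothetical rendered command for 16 processes with one export:
#   srun -n 16 --export=ALL,NCCL_DEBUG=INFO python -u train.py <user args>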
command = srun_cmd + [exports] + python_exec + [self.user_script
] + self.user_arguments
return command
class MVAPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
@@ -165,7 +278,8 @@ class MVAPICHRunner(MultiNodeRunner):
self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1')
# Enabled cuda-aware communication
self.add_export('MV2_USE_CUDA', '1')
if get_accelerator().device_name() == 'cuda':
self.add_export('MV2_USE_CUDA', '1')
# Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/
self.add_export('MV2_SUPPORT_DL', '1')
@@ -227,7 +341,7 @@ class MVAPICHRunner(MultiNodeRunner):
f'{process_per_node}',
'--hostfile',
f'{MVAPICH_TMP_HOSTFILE}',
]
] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
......