Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
@@ -12,6 +15,7 @@ from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc
from deepspeed.accelerator import get_accelerator
import re
def load_model_with_checkpoint(r_module,
@@ -24,6 +28,15 @@ def load_model_with_checkpoint(r_module,
                               container=None):
    error_msgs = []

    def prefix_check():
        # if keys start with 'model.', don't skip level 0 prefix
        for key in sd[0].keys():
            if re.match("^model[.]", key):
                return False
        return True

    skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix
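For orientation, a hedged sketch of what prefix_check decides (the checkpoint keys below are invented for illustration, not part of this merge):

# Illustrative only: prefix_check() returns False as soon as any checkpoint key
# already starts with 'model.', so the level-0 prefix is then NOT skipped.
import re

def prefix_check_demo(keys):
    for key in keys:
        if re.match("^model[.]", key):
            return False
    return True

assert prefix_check_demo(["decoder.layers.0.attn.weight"]) is True
assert prefix_check_demo(["model.decoder.layers.0.attn.weight"]) is False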
    def transpose(data):
        with torch.no_grad():
            data = data.contiguous()
@@ -40,10 +53,8 @@ def load_model_with_checkpoint(r_module,
            if prefix + 'bias' in sd[0].keys():
                if module.bias.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
    args = None
    gc.collect()
@@ -61,86 +72,62 @@ def load_model_with_checkpoint(r_module,
                    # set the quantizer number of groups using the checkpoint scale shape
                    weight_quantizer.num_groups = scale.shape[0]
                else:
                    tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
                    scale = None
                src_shape = tmp_data.shape
                dst_shape = p.shape
                inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
                outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
                if (len(src_shape) == 2 and len(dst_shape) == 2):
                    if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
                        if tmp_data.dtype != torch.int8:
                            p = weight_quantizer.quantize(transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
                        else:
                            p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
                            p.scale = scale
                        setattr(module, n, p)
                    else:
                        dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
                        dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
                        if src_shape[dim] > dst_shape[dim1]:
                            weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
                                get_accelerator().current_device_name())
                            assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank + 1), \
                                '''ERROR: We require the quantization scales for larger TP-size when loading an INT8 checkpoint!\
                                   Please use the FP16 checkpoint to generate an INT8 checkpoint with the sharding parameters!'''
                            scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
                                weight_quantizer.num_groups, -1).contiguous()
                        else:
                            assert tmp_data.dtype != torch.int8, \
                                '''Merging of checkpoints is not supported when using an INT8 checkpoint! \
                                   Please use as many GPUs as the TP-size of the checkpoint'''
                            all_data = [
                                sd[j][prefix + n] if type(sd[j][prefix + n]) is list else sd[j][prefix + n].to(
                                    get_accelerator().current_device_name()) for j in range(len(sd))
                            ]
                            # Check if the weight tensor is for the QKV parameter
                            if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
                                qkv_size = src_shape[outer_dim] // 3
                                src_split = [
                                    torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
                                ]
                                weight_partition = torch.cat([
                                    torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
                                    for i in range(len(src_split[0]))
                                ],
                                                             dim=dim)
                            else:
                                weight_partition = torch.cat([
                                    ad[0].to(get_accelerator().current_device_name())
                                    if type(ad) is list else ad for ad in all_data
                                ],
                                                             dim=dim)
                            if tmp_data.dtype == torch.int8:
                                scale = torch.cat(
                                    [ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
                                    dim=dim)

                        if tmp_data.dtype != torch.int8:
                            weight_partition = weight_quantizer.quantize(
@@ -148,9 +135,8 @@ def load_model_with_checkpoint(r_module,
                                parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
                                weight_quantizer.quantize(weight_partition)
                        else:
                            weight_partition = torch.nn.parameter.Parameter(weight_partition, requires_grad=False)
                        weight_partition.scale = scale
                        setattr(module, n, weight_partition)
                else:
@@ -158,42 +144,27 @@ def load_model_with_checkpoint(r_module,
                        p.data.copy_(tmp_data)
                else:
                    if src_shape[0] > dst_shape[0]:
                        bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
                            get_accelerator().current_device_name()).contiguous()
                        p.data.copy_(bias_split)
                    else:
                        # Check if the weight tensor is for the QKV parameter
                        if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
                            qkv_size = src_shape[0] // 3
                            src_split = [
                                torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
                            ]
                            p.data.copy_(
                                torch.cat([
                                    torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
                                    for i in range(len(src_split[0]))
                                ],
                                          dim=0).to(get_accelerator().current_device_name()).contiguous())
                        else:
                            p.data.copy_(
                                torch.cat([sd[j][prefix + n] for j in range(len(sd))],
                                          dim=0).to(get_accelerator().current_device_name()).contiguous())

        load_parameters(module, prefix)
        for n, child in module.named_children():
@@ -239,20 +210,16 @@ def load_model_with_checkpoint(r_module,
                    setattr(module, name, child)
                    continue
                child_params = list(child.parameters())
                if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
                    if child.weight.is_meta:
                        ds_shape = child.weight.shape
                    else:
                        ds_shape = child.weight.ds_shape
                    if child.__class__ is nn.LayerNorm:
                        child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                        setattr(module, name, child)
                    elif child.__class__ is nn.Linear:
                        child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
                        setattr(module, name, child)
                    elif child.__class__ is OPTLearnedPositionalEmbedding:
                        child = OPTEmbedding(weight_shape=ds_shape)
@@ -261,8 +228,7 @@ def load_model_with_checkpoint(r_module,
                        ds_id = None
                        if hasattr(child.weight, 'ds_id'):
                            ds_id = child.weight.ds_id
                        child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
                        if ds_id is not None:
                            all_ds_ids[ds_id] = child.weight
                        setattr(module, name, child)
@@ -270,7 +236,7 @@ def load_model_with_checkpoint(r_module,
            else:
                load_module_recursive(
                    child,
                    prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix else \
                    prefix + name + '.',
                    level + 1)
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
@@ -18,34 +21,25 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
    Returns:
        Updated nn.module with quantized transformer layers
    """

    def quantize_weight(weight):
        return weight.to(torch.int8)
    def megatron_layer_quantize(layer):
        layer.attention.query_key_value.weight.data = quantize_weight(layer.attention.query_key_value.weight.data)
        layer.attention.dense.weight.data = quantize_weight(layer.attention.dense.weight.data)
        layer.mlp.dense_h_to_4h.weight.data = quantize_weight(layer.mlp.dense_h_to_4h.weight.data)
        layer.mlp.dense_4h_to_h.weight.data = quantize_weight(layer.mlp.dense_4h_to_h.weight.data)

    def bert_layer_quantize(layer):
        layer.attention.self.query.weight.data = quantize_weight(layer.attention.self.query.weight.data)
        layer.attention.self.key.weight.data = quantize_weight(layer.attention.self.key.weight.data)
        layer.attention.self.value.weight.data = quantize_weight(layer.attention.self.value.weight.data)
        layer.attention.output.dense.weight.data = quantize_weight(layer.attention.output.dense.weight.data)
        if preln:
            layer.intermediate.dense_act.weight.data = quantize_weight(layer.intermediate.dense_act.weight.data)
        else:
            layer.intermediate.dense.weight.data = quantize_weight(layer.intermediate.dense.weight.data)
        layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)

    def quantize_fn(child):
@@ -58,9 +52,7 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False):
        return child

    return quantize_module(model=model, orig_class=orig_layer_impl, quantize_fn=quantize_fn)
def quantize_module(model, orig_class, quantize_fn):
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod
from deepspeed.utils.types import ActivationFuncType
import torch
@@ -70,7 +72,7 @@ class TransformerPolicy(DSPolicy):
        self.split_qkv = split_qkv

    @abstractmethod
    def attention(self, enable_training=False):
        """
        Returns attention qkv and dense parameters
        weight: (3*hidden, hidden) and (hidden, hidden)
@@ -78,6 +80,13 @@ class TransformerPolicy(DSPolicy):
        """
        raise NotImplementedError
    @abstractmethod
    def get_q_k_v(self):
        """
        return all q, k, v parameters without merging them together
        """
        raise NotImplementedError

    @abstractmethod
    def get_hidden_heads(self):
        """
@@ -103,6 +112,14 @@ class TransformerPolicy(DSPolicy):
        """
        raise NotImplementedError

    @abstractmethod
    def get_lora_params(self):
        """
        Returns lora parameters used in transformer layer
        """
        raise NotImplementedError
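To make the two new hooks concrete, a minimal hedged sketch of how a policy might implement them (the q_proj/k_proj/v_proj attribute layout is invented for illustration, and the remaining abstract methods are elided, so this fragment is not instantiable as-is):

# Hypothetical policy fragment, not from this merge.
class ToyAttnPolicy(TransformerPolicy):

    def get_q_k_v(self):
        # q, k and v returned separately instead of as one fused qkv tensor
        attn = self.client_module.attention
        return (attn.q_proj.weight, attn.q_proj.bias,
                attn.k_proj.weight, attn.k_proj.bias,
                attn.v_proj.weight, attn.v_proj.bias)

    def get_lora_params(self):
        # collect LoRA adapters (if any) attached to the attention projections
        attn = self.client_module.attention
        return [maybe_get_lora(w) for w in (attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight)]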
# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
@@ -124,15 +141,10 @@ def _transpose(x, heads=1, mp_replace=None):
    (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=-1)
    if len(q.shape) > 2:
        new_shape = (q.shape[0], ) + (-1, )
        return torch.cat((q.reshape(new_shape), k.reshape(new_shape), v.reshape(new_shape)),
                         dim=outer_dim).reshape(x.shape)
    else:
        return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)


# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
@@ -156,19 +168,14 @@ def maybe_copy(module,
    else:
        dst = mp_replace.copy(dst, tmp)
    if qkv and megatron_v2:
        dst = torch.nn.parameter.Parameter(_transpose(dst, heads=heads, mp_replace=mp_replace).contiguous())
    else:
        if split_qkv:
            dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
                (transpose(tmp).contiguous())), int8=weight_quantizer.q_int8)
        else:
            if qkv and megatron_v2:
                tmp = _transpose(transpose(tmp), heads=heads, mp_replace=mp_replace).contiguous()
            if weight_quantizer.q_int8:
                tmp = transpose(tmp)
            dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
@@ -177,13 +184,7 @@ def maybe_copy(module,


# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names, split_qkv=False):
    if src_names[0] in sd:
        q = sd[src_names[0]]
        k = sd[src_names[1]]
@@ -203,3 +204,19 @@ def maybe_copy_qkv(module,
        dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \
            transpose(qkv_data)), int8=weight_quantizer.q_int8)
        setattr(module, dst_name, dst)

def pack_lora_weights(p):
    return [p.lora_right_weight, \
            p.lora_left_weight, \
            p.lora_scaling]


def maybe_get_lora(p):
    if hasattr(p, 'lora_right_weight'):
        lora_param = pack_lora_weights(p)
    else:
        lora_param = []
    return lora_param
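A small hedged illustration of the maybe_get_lora contract (the toy tensors and scaling value are invented):

# Illustrative only: a parameter without LoRA attributes yields an empty list;
# one carrying the three attributes yields [right, left, scaling].
import torch

p = torch.nn.Parameter(torch.zeros(8, 8))
assert maybe_get_lora(p) == []

p.lora_right_weight = torch.zeros(8, 2)
p.lora_left_weight = torch.zeros(2, 8)
p.lora_scaling = 2.0
right, left, scaling = maybe_get_lora(p)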
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
@@ -23,6 +26,7 @@ from .utils import policy_to_ds_container

class ReplaceWithTensorSlicing:

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
@@ -38,7 +42,7 @@ class ReplaceWithTensorSlicing:
                for merging your checkpoints before replacing the transformer layer with\
                inference-kernels'

    def qkv_copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        src_shape = src.shape
@@ -47,6 +51,9 @@ class ReplaceWithTensorSlicing:
        outer_dim = 0 if int8 else -1
        inner_dim = -1 if int8 else 0

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // 3, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
@@ -55,82 +62,59 @@ class ReplaceWithTensorSlicing:
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // 3
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // 3
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
    def copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
@@ -150,6 +134,7 @@ def get_transformer_name(replaced_module):

class GroupQuantizer:

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
@@ -163,8 +148,7 @@ class GroupQuantizer:
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
@@ -174,31 +158,14 @@ class GroupQuantizer:
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
                  for i in range(2)]

        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
                                                                                                   -1).contiguous()
        return out
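For intuition, a hedged sketch of the symmetric per-group scheme used above: each group's scale is 2 * max(|min|, |max|) / 2**num_bits, and dequantization multiplies the int8 values back by that scale (the helper below is a toy round-trip, not part of the merge):

# Toy one-group round-trip matching the scale formula in GroupQuantizer.
import torch

def toy_quantize(x, num_bits=8):
    q_range = 2**num_bits
    scale = x.abs().max() * 2.0 / q_range                  # symmetric group scale
    q = (x / scale).round().clamp(-q_range // 2, q_range // 2 - 1).to(torch.int8)
    return q, scale

x = torch.randn(64)
q, scale = toy_quantize(x)
x_hat = q.float() * scale                                  # dequantize
# reconstruction error is bounded by roughly one quantization step (= scale)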
@@ -211,6 +178,7 @@ def _module_match(module):
def generic_injection(module, fp16=False, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
@@ -246,8 +214,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)

        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
@@ -262,7 +229,10 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
        try:
            import diffusers
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
@@ -278,8 +248,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
        #                     triangular_masking=True,
        #                     max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
@@ -291,13 +260,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
                for name, child in module.named_children():
                    _replace_module(child, policy)
                    if child.__class__ in new_policies:
                        replaced_module = new_policies[child.__class__](child, policy)
                        setattr(module, name, replaced_module)

            _replace_module(sub_module, policy)
            new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
            print(f"**** found and replaced {name} w. {type(new_module)}")
            setattr(module, name, new_module)
@@ -305,11 +272,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@@ -334,15 +297,10 @@ def replace_transformer_layer(orig_layer_impl,
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)
    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported raise an error if set
@@ -364,8 +322,7 @@ def replace_transformer_layer(orig_layer_impl,
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()
@@ -411,25 +368,21 @@ def replace_transformer_layer(orig_layer_impl,
        if name in all_reduce_linears:
            new_weight = torch.empty((
                weight_shape[1] if conv_linear_layer else weight_shape[0],
                (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size,
            ),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            if conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = mp_replace.copy(new_weight, child.weight.data)
            new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype)
            if child.bias is not None:
                new_bias.data.copy_(child.bias.data)
            return LinearAllreduce(data, child.bias if child.bias is None else \
                torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
        else:
            new_weight = torch.empty((
                (weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
                weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
            ),
                                     device=child.weight.device,
@@ -441,43 +394,54 @@ def replace_transformer_layer(orig_layer_impl,
            new_bias = torch.empty((weight_shape[0] // mp_size),
                                   device=child.weight.device,
                                   dtype=child.weight.dtype)
            bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
                get_accelerator().current_device_name())
            return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)

    def _slice_embedding(child, name, conv_linear_layer):
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
                                 device=child.weight.device,
                                 dtype=child.weight.dtype)
        data = mp_replace.copy(new_weight,
                               child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
                               child.weight.data)
        new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
        new_embedding.weight.data.copy_(data)
        return new_embedding
    def update_mp_params(child):
        if hasattr(child, 'n_heads'):
            assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
                child.n_heads, mp_size)
            child.n_heads = child.n_heads // mp_size
        if hasattr(child, 'inner_dim'):
            assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
                child.inner_dim, mp_size)
            child.inner_dim = child.inner_dim // mp_size
        if hasattr(child, 'num_heads'):
            assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
                child.num_heads, mp_size)
            child.num_heads = child.num_heads // mp_size
        if hasattr(child, 'num_attention_heads'):
            assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
                child.num_attention_heads, mp_size)
            child.num_attention_heads = child.num_attention_heads // mp_size
        if hasattr(child, 'num_attn_heads'):
            assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
                child.num_attn_heads, mp_size)
            child.num_attn_heads = child.num_attn_heads // mp_size
        if hasattr(child, 'all_head_size'):
            assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
                child.all_head_size, mp_size)
            child.all_head_size = child.all_head_size // mp_size
        if hasattr(child, 'embed_dim'):
            assert child.embed_dim % mp_size == 0, "embed_dim ({}) must be divisible by mp_size ({})".format(
                child.embed_dim, mp_size)
            child.embed_dim = child.embed_dim // mp_size
        if hasattr(child, 'hidden_size'):
            assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
                child.hidden_size, mp_size)
            child.hidden_size = child.hidden_size // mp_size
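The asserts above make the divisibility contract for each sharded attribute explicit; a tiny hedged illustration (the config object is a stand-in, not a real transformers class):

# Illustrative only: for mp_size=2, head counts and hidden sizes simply halve.
from types import SimpleNamespace

child = SimpleNamespace(num_attention_heads=16, hidden_size=1024)
mp_size = 2
assert child.num_attention_heads % mp_size == 0
child.num_attention_heads //= mp_size   # -> 8
child.hidden_size //= mp_size           # -> 512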
    conv_linear_layer = False
@@ -499,12 +463,8 @@ def replace_transformer_layer(orig_layer_impl,
    def _replace_module(r_module, prev_name=''):
        for name, child in r_module.named_children():
            if child.__class__ in linear_policies:
                setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                         conv_linear_layer))
            else:
                update_mp_params(child)
                _replace_module(child, name)
@@ -551,15 +511,10 @@ def replace_transformer_layer(orig_layer_impl,
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
@@ -574,22 +529,15 @@ def replace_transformer_layer(orig_layer_impl,
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
@@ -602,15 +550,13 @@ def replace_transformer_layer(orig_layer_impl,
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
@@ -649,37 +595,22 @@ def replace_transformer_layer(orig_layer_impl,
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({k: v
                             for k, v in dict(replaced_module.state_dict()).items()
                             if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        rep_sd = replaced_module.state_dict()
@@ -691,13 +622,9 @@ def replace_transformer_layer(orig_layer_impl,
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module
@@ -712,6 +639,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)
@@ -813,9 +741,7 @@ def _replace_module(model, policies, layer_id=0):
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child, policies[child.__class__][-1], layer_id)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .containers import HFGPT2LayerPolicy
from .containers import HFBertLayerPolicy
from .containers import BLOOMLayerPolicy
@@ -16,16 +18,8 @@ from .containers import VAEPolicy

# transformer-based policies
replace_policies = [
    HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy,
    HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy
]

# non-transformer-based policies
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.utils import log_dist
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import copy


class Experts(torch.nn.Module):

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
@@ -31,6 +32,7 @@ class MoE(torch.nn.Module):
        use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
        enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
    """

    def __init__(self,
                 hidden_size,
                 expert,
@@ -65,15 +67,8 @@ class MoE(torch.nn.Module):
            'Unsupported noisy_gate_policy: ' + noisy_gate_policy

        experts = Experts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),
                                      experts,
                                      self.expert_group_name,
                                      self.ep_size,
@@ -90,20 +85,16 @@ class MoE(torch.nn.Module):
    def _create_process_groups(self):
        # Create process group for a layer if needed
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                # Condition 1 - no groups.mpu means no tensor parallelism
                # Condition 2 - disabling expert tensor parallelism on purpose
                groups._create_expert_and_data_parallel(self.ep_size)
            else:
                # expert tensor parallelism is enabled
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu)
        # Set the group handle for the MOELayer (deepspeed_moe) object
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, hidden_states, used_token=None):
        """ MoE forward
...
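Given the constructor knobs documented above, a hedged construction sketch (the hidden size, expert definition, and expert count are toy values):

# Hypothetical usage of the MoE layer; forward returns (output, l_aux, exp_counts).
import torch
from deepspeed.moe.layer import MoE

expert = torch.nn.Sequential(torch.nn.Linear(512, 2048), torch.nn.ReLU(), torch.nn.Linear(2048, 512))
moe = MoE(hidden_size=512, expert=expert, num_experts=8, k=1)
output, l_aux, exp_counts = moe(torch.randn(4, 16, 512))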
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
# The file has been adapted from the following Megatron-LM file:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
@@ -32,14 +33,9 @@ def _gather_tokens(input_, dim=0):
    # Size and dimension.
    rank = mpu.get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
    tensor_list[rank] = input_
    deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -53,7 +49,8 @@ def _drop_tokens(input_, dim=0):
    total_chunks = mpu.get_tensor_model_parallel_world_size()
    this_chunk = mpu.get_tensor_model_parallel_rank()
    assert input_.shape[
        dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
    chunk_size = input_.shape[dim] // total_chunks

    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
@@ -61,6 +58,7 @@ def _drop_tokens(input_, dim=0):
class _GatherTokens(torch.autograd.Function):
    """All gather tokens among the tensor parallel ranks"""

    @staticmethod
    def symbolic(graph, input_, dim):
        return _gather_tokens(input_, dim)
...
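The _drop_tokens/_gather_tokens pair diffed above shards the token dimension across tensor-parallel ranks and reassembles it. A single-process sketch of the same narrow-based chunking, with plain integers standing in for the mpu world-size/rank queries (the helper name here is hypothetical):

# Single-process sketch of the chunking done by _drop_tokens (illustrative only).
import torch

def drop_tokens_local(input_, total_chunks, this_chunk, dim=0):
    # same divisibility requirement as the assert above
    assert input_.shape[dim] % total_chunks == 0
    chunk_size = input_.shape[dim] // total_chunks
    return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)

tokens = torch.arange(8).reshape(8, 1)
shard = drop_tokens_local(tokens, total_chunks=4, this_chunk=1)  # rows 2..3
# _gather_tokens is the inverse: all-gather the shards, then torch.cat along `dim`.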
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
The file has been adapted from two fairscale files:
 (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
 (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
 Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
 We retain the following license from the original files:
"""
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
@@ -60,11 +63,9 @@ def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
        return x
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
                                                      high=torch.tensor(1.0 + epsilon,
                                                                        device=device)).rsample  # type: ignore
        uniform_map[device] = uniform
    return x * uniform(x.shape)
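In effect, multiplicative_jitter scales each element by an independent draw from Uniform(1 - epsilon, 1 + epsilon); a standalone sketch (epsilon value illustrative):

# Standalone sketch of the multiplicative jitter above.
import torch

epsilon = 1e-2
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon),
                                              high=torch.tensor(1.0 + epsilon)).rsample
x = torch.ones(4)
jittered = x * uniform(x.shape)  # every entry now lies in (0.99, 1.01)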
@@ -87,6 +88,7 @@ from deepspeed import comm as dist
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):

    @staticmethod
    def forward(
            ctx: Any,
@@ -181,25 +183,18 @@ def top1gating(logits: Tensor,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
@@ -225,18 +220,16 @@ def top1gating(logits: Tensor,
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
                                                          high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[
        0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)
@@ -258,7 +251,13 @@ def top1gating(logits: Tensor,
    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [
            indices1_s,
        ], [
            locations1_s,
        ], [
            gates1_s,
        ], exp_counts
    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

@@ -275,19 +274,12 @@ def top1gating(logits: Tensor,
    return l_aux, combine_weights, dispatch_mask, exp_counts
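For readers skimming the gating math: both gating functions derive a per-expert token budget from _capacity. Assuming it follows the usual GShard-style rule (tokens per expert scaled by the capacity factor, floored at min_capacity; this restatement is ours, not part of the merge), the arithmetic is:

# Sketch of the capacity rule assumed by top1gating/top2gating (illustrative).
import math

def capacity_sketch(num_tokens: int, num_experts: int,
                    capacity_factor: float, min_capacity: int) -> int:
    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
    return max(capacity, min_capacity)

capacity_sketch(num_tokens=4096, num_experts=8, capacity_factor=1.0, min_capacity=4)  # -> 512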
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits.""" """Implements Top2Gating on logits."""
# everything is in fp32 in this function # everything is in fp32 in this function
gates = F.softmax(logits, dim=1) gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
torch.tensor(capacity_factor * 2),
torch.tensor(min_capacity))
    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1)
@@ -393,13 +385,10 @@ class TopKGate(Module):
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(self,
                input: torch.Tensor,
                used_token: torch.Tensor = None,
                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
        if self.wall_clock_breakdown:
            self.timers('TopKGate').start()
@@ -413,21 +402,13 @@ class TopKGate(Module):
        logits = self.wg(input_fp32)

        if self.k == 1:
            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
                                     self.drop_tokens, self.use_rts, use_tutel)

        else:
            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
                                     self.min_capacity)
        if self.wall_clock_breakdown:
            self.timers('TopKGate').stop()
@@ -453,6 +434,7 @@ class MOELayer(Base):
        expert (torch.nn.Module):
            expert network
    """

    def __init__(self,
                 gate: Module,
                 experts: Module,
@@ -481,9 +463,8 @@ class MOELayer(Base):
logger.warning("Tutel optimization requested but not installed. " logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.") "Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1: elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning( logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
"To enable Tutel optimization, use top-1 instead of top-2 gate. " "Proceeding without Tutel.")
"Proceeding without Tutel.")
    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group
@@ -506,18 +487,12 @@ class MOELayer(Base):
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
            dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
        if self.wall_clock_breakdown:
            self.timers('falltoall').start()

@@ -538,10 +513,7 @@ class MOELayer(Base):
            self.time_falltoall = self.timers('falltoall').elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)
@@ -555,9 +527,7 @@ class MOELayer(Base):
            self.time_salltoall = self.timers('salltoall').elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        if groups._get_expert_model_parallel_world_size() == 1:
            # the dropped duplicate tokens need to be gathered on each

@@ -568,9 +538,7 @@ class MOELayer(Base):
        if self.use_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
        else:
            combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)

        a = combined_output.reshape(input[0].shape)
...
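The two einsums diffed above carry the whole dispatch/combine round trip. A shape-level sketch with toy sizes (s tokens, e experts, c capacity slots, m model dim; the random mask and weights are stand-ins for real gate outputs):

# Shape-level sketch of the MOELayer dispatch/combine einsums (toy sizes, illustrative).
import torch

s, e, c, m = 6, 2, 3, 4                   # tokens, experts, capacity, d_model
reshaped_input = torch.randn(s, m)
dispatch_mask = torch.zeros(s, e, c)      # one-hot (expert, slot) per routed token
combine_weights = torch.rand(s, e, c)     # gate probabilities in the same layout

# dispatch: route each token vector into its (expert, capacity-slot) bucket
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
# ... the experts (and the all-to-alls) run on `dispatched` here ...
# combine: weighted sum of expert outputs back into per-token vectors
combined = torch.einsum("sec,ecm->sm", combine_weights, dispatched)  # (s, m)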
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from typing import List, Tuple, Dict
import torch
@@ -24,8 +27,7 @@ def is_moe_param(param: torch.Tensor) -> bool:
def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    shared_params, expert_params = [], []
    for p in params:
        if is_moe_param(p):
@@ -36,8 +38,7 @@ def split_params_into_shared_and_expert_params(
def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    """Split grad of parameters into grads of non-expert params
    and grads of expert params. This is useful while computing
    grad-norms for clipping and overflow detection
@@ -62,8 +63,7 @@ def split_params_grads_into_shared_and_expert_params(
def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
                                                         max_group_size=178956971) -> Tuple[Dict]:
    """Split parameters into different MoE groups for optimizer

    Args:

@@ -101,8 +101,7 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
                if ori_key == 'params':
                    group_moe[param_group['name']][key][ori_key] = []
                else:
                    group_moe[param_group['name']][key][ori_key] = param_group[ori_key]
    # Assign param
    for param_group in param_groups:
        new_params = []
...
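A short sketch of how the splitting helper above is typically fed when building optimizer groups (the model and the 'name' key are illustrative; expert params get pulled out into their own groups so grad-norm and data-parallel handling can treat them separately):

# Sketch (illustrative): routing param groups through the MoE splitting helper.
import torch
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

def build_param_groups(model: torch.nn.Module):
    param_groups = [{'params': list(model.parameters()), 'name': 'base'}]
    return split_params_into_different_moe_groups_for_optimizer(param_groups)

# optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)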
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from pydantic import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
def get_monitor_config(param_dict):
    monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")}
    return DeepSpeedMonitorConfig(**monitor_dict)
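As a sketch, the slice of a DeepSpeed config that get_monitor_config consumes might look like this (paths, job names, and enabled flags are illustrative; absent keys simply default to {}):

# Sketch (illustrative values): the monitoring slice of a DeepSpeed config dict.
param_dict = {
    "tensorboard": {
        "enabled": True,
        "output_path": "output/ds_logs/",
        "job_name": "train_run"
    },
    "wandb": {"enabled": False},
    "csv_monitor": {"enabled": False},
}
monitor_config = get_monitor_config(param_dict)
# The check_enabled validator below flips the aggregate `enabled` flag to True
# because the tensorboard sub-config is enabled.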
@@ -78,10 +71,10 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
    csv_monitor: CSVConfig = {}
    """ Local CSV output of monitoring data. """

    @root_validator
    def check_enabled(cls, values):
        values["enabled"] = False
        if (values.get("tensorboard").enabled or values.get("wandb").enabled or values.get("csv_monitor").enabled):
            values["enabled"] = True
        return values
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .monitor import Monitor
import os
@@ -7,6 +10,7 @@ import deepspeed.comm as dist
class csvMonitor(Monitor):

    def __init__(self, csv_config):
        super().__init__(csv_config)
        self.filenames = []
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
""" """
Support different forms of monitoring such as wandb and tensorboard Support different forms of monitoring such as wandb and tensorboard
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -8,6 +11,7 @@ import deepspeed.comm as dist
class Monitor(ABC):

    @abstractmethod
    def __init__(self, monitor_config):
        self.monitor_config = monitor_config
@@ -23,6 +27,7 @@ from .csv_monitor import csvMonitor
class MonitorMaster(Monitor):

    def __init__(self, monitor_config):
        super().__init__(monitor_config)
        self.tb_monitor = None
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .utils import check_tb_availability
from .monitor import Monitor
@@ -8,6 +11,7 @@ import deepspeed.comm as dist
class TensorBoardMonitor(Monitor):

    def __init__(self, tensorboard_config):
        super().__init__(tensorboard_config)
        check_tb_availability()
@@ -20,9 +24,7 @@ class TensorBoardMonitor(Monitor):
        if self.enabled and dist.get_rank() == 0:
            self.get_summary_writer()

    def get_summary_writer(self, base=os.path.join(os.path.expanduser("~"), "tensorboard")):
        if self.enabled and dist.get_rank() == 0:
            from torch.utils.tensorboard import SummaryWriter
            if self.output_path is not None:
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
def check_tb_availability():
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .utils import check_wandb_availability
from .monitor import Monitor
@@ -7,6 +10,7 @@ import deepspeed.comm as dist
class WandbMonitor(Monitor):

    def __init__(self, wandb_config):
        super().__init__(wandb_config)
        check_wandb_availability()
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''