Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
......@@ -12,6 +15,7 @@ from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc
from deepspeed.accelerator import get_accelerator
import re
def load_model_with_checkpoint(r_module,
......@@ -24,6 +28,15 @@ def load_model_with_checkpoint(r_module,
container=None):
error_msgs = []
def prefix_check():
# if keys start with 'model.', don't skip level 0 prefix
for key in sd[0].keys():
if re.match("^model[.]", key):
return False
return True
skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix
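# Hedged sketch (not part of this diff) of the prefix check above: the level-0
# prefix is skipped only when no checkpoint key already carries a 'model.' prefix
# and the policy requests a load prefix. The sample keys are hypothetical.
import re

def _skip_level_0_prefix(state_dict_keys, use_load_prefix=True):
    has_model_prefix = any(re.match(r"^model[.]", k) for k in state_dict_keys)
    return (not has_model_prefix) and use_load_prefix

print(_skip_level_0_prefix(["model.decoder.layers.0.weight"]))  # False -> keep the level-0 prefix
print(_skip_level_0_prefix(["decoder.layers.0.weight"]))        # True  -> skip it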
def transpose(data):
with torch.no_grad():
data = data.contiguous()
......@@ -40,10 +53,8 @@ def load_model_with_checkpoint(r_module,
if prefix + 'bias' in sd[0].keys():
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
args = None
gc.collect()
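# Hedged sketch (not part of this diff) of the meta-tensor replacement pattern used
# above: a meta tensor has no storage, so copy_() cannot target it; it is first swapped
# for a real CPU tensor of the same shape/dtype. The toy Linear layer is illustrative.
import torch

lin = torch.nn.Linear(4, 4, device="meta")
assert lin.bias.is_meta  # no storage yet; a direct copy_() would fail here
lin.bias = torch.nn.parameter.Parameter(data=torch.empty_like(lin.bias.data, device="cpu"),
                                        requires_grad=lin.bias.data.requires_grad)
lin.bias.data.copy_(torch.zeros(4))  # now a normal checkpoint copy succeeds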
......@@ -61,86 +72,62 @@ def load_model_with_checkpoint(r_module,
# set the quantizer number of groups using the checkpoint scale shape
weight_quantizer.num_groups = scale.shape[0]
else:
tmp_data = sd[0][prefix + n].to(
get_accelerator().current_device_name())
tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
scale = None
src_shape = tmp_data.shape
dst_shape = p.shape
inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
if (len(src_shape) == 2 and len(dst_shape) == 2):
if (src_shape[inner_dim] == dst_shape[0]
and src_shape[outer_dim] == dst_shape[1]):
if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
if tmp_data.dtype != torch.int8:
p = weight_quantizer.quantize(
transpose(tmp_data) if weight_quantizer.
q_int8 else tmp_data)
transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
else:
p = torch.nn.parameter.Parameter(tmp_data,
requires_grad=False)
p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
p.scale = scale
setattr(module, n, p)
else:
dim = inner_dim if src_shape[inner_dim] != dst_shape[
0] else outer_dim
dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
if src_shape[dim] > dst_shape[dim1]:
weight_partition = torch.split(
tmp_data,
dst_shape[dim1],
dim=dim)[rank].to(
get_accelerator().current_device_name())
weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
get_accelerator().current_device_name())
assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \
'''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\
Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!'''
scale = scale.view(
-1)[weight_quantizer.num_groups *
(rank + 1):].reshape(
weight_quantizer.num_groups,
-1).contiguous()
scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
weight_quantizer.num_groups, -1).contiguous()
else:
assert tmp_data.dtype != torch.int8, \
'''Merging of the checkpoints is not supported when using an INT8 checkpoint! \
Please use as many GPUs as the TP-size for the checkpoint'''
all_data = [
sd[j][prefix +
n] if type(sd[j][prefix + n]) is list else
sd[j][prefix + n].to(
get_accelerator().current_device_name())
for j in range(len(sd))
sd[j][prefix + n] if type(sd[j][prefix + n]) is list else sd[j][prefix + n].to(
get_accelerator().current_device_name()) for j in range(len(sd))
]
# Check if the weight tensor is for the QKV parameter
if src_shape[1] == (3 *
src_shape[0]) // ckpt_mp_size:
if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
qkv_size = src_shape[outer_dim] // 3
src_split = [
torch.split(src[0].data,
qkv_size,
dim=outer_dim)
for src in all_data
torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
]
weight_partition = torch.cat([
torch.cat([qkv_s[i] for qkv_s in src_split],
axis=outer_dim)
torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
for i in range(len(src_split[0]))
],
dim=dim)
else:
weight_partition = torch.cat([
ad[0].to(
get_accelerator().current_device_name())
if type(ad) is list else ad
for ad in all_data
ad[0].to(get_accelerator().current_device_name())
if type(ad) is list else ad for ad in all_data
],
dim=dim)
if tmp_data.dtype == torch.int8:
scale = torch.cat([
ad[1].to(
get_accelerator().current_device_name())
for ad in all_data
],
dim=dim)
scale = torch.cat(
[ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
dim=dim)
if tmp_data.dtype != torch.int8:
weight_partition = weight_quantizer.quantize(
......@@ -148,9 +135,8 @@ def load_model_with_checkpoint(r_module,
parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(weight_partition)
else:
weight_partition = torch.nn.parameter.Parameter(
weight_partition,
requires_grad=False)
weight_partition = torch.nn.parameter.Parameter(weight_partition,
requires_grad=False)
weight_partition.scale = scale
setattr(module, n, weight_partition)
else:
......@@ -158,42 +144,27 @@ def load_model_with_checkpoint(r_module,
p.data.copy_(tmp_data)
else:
if src_shape[0] > dst_shape[0]:
bias_split = torch.split(
tmp_data,
dst_shape[-1])[rank].to(get_accelerator(
).current_device_name()).contiguous()
bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
get_accelerator().current_device_name()).contiguous()
p.data.copy_(bias_split)
else:
# Check if the weight tensor is for the QKV parameter
if src_shape[0] == (3 * r_module.config.hidden_size
) // ckpt_mp_size:
if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
qkv_size = src_shape[0] // 3
src_split = [
torch.split(sd[j][prefix + n],
qkv_size,
dim=0) for j in range(len(sd))
torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
]
p.data.copy_(
torch.cat(
[
torch.cat([
qkv_s[i] for qkv_s in src_split
],
axis=0)
for i in range(len(src_split[0]))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
torch.cat([
torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
for i in range(len(src_split[0]))
],
dim=0).to(get_accelerator().current_device_name()).contiguous())
else:
p.data.copy_(
torch.cat(
[
sd[j][prefix + n]
for j in range(len(sd))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
torch.cat([sd[j][prefix + n] for j in range(len(sd))],
dim=0).to(get_accelerator().current_device_name()).contiguous())
load_parameters(module, prefix)
for n, child in module.named_children():
......@@ -239,20 +210,16 @@ def load_model_with_checkpoint(r_module,
setattr(module, name, child)
continue
child_params = list(child.parameters())
if len(child_params) > 0 and (child_params[0].numel() == 0
or child_params[0].is_meta):
if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
if child.weight.is_meta:
ds_shape = child.weight.shape
else:
ds_shape = child.weight.ds_shape
if child.__class__ is nn.LayerNorm:
child = Normalize(dim=ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
setattr(module, name, child)
elif child.__class__ is nn.Linear:
child = LinearLayer(weight_shape=child.weight.shape,
bias=child.bias)
child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
......@@ -261,8 +228,7 @@ def load_model_with_checkpoint(r_module,
ds_id = None
if hasattr(child.weight, 'ds_id'):
ds_id = child.weight.ds_id
child = EmbeddingLayer(weight_shape=ds_shape,
dtype=child.weight.dtype)
child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
if ds_id is not None:
all_ds_ids[ds_id] = child.weight
setattr(module, name, child)
......@@ -270,7 +236,7 @@ def load_model_with_checkpoint(r_module,
else:
load_module_recursive(
child,
prefix if (level == 0 and ckpt_type == 'pp') and container.policy.use_load_prefix else \
prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix else \
prefix + name + '.',
level + 1)
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
......@@ -18,34 +21,25 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal
Returns:
Updated nn.module with quantized transformer layers
"""
def quantize_weight(weight):
return weight.to(torch.int8)
def megatron_layer_quantize(layer):
layer.attention.query_key_value.weight.data = quantize_weight(
layer.attention.query_key_value.weight.data)
layer.attention.dense.weight.data = quantize_weight(
layer.attention.dense.weight.data)
layer.mlp.dense_h_to_4h.weight.data = quantize_weight(
layer.mlp.dense_h_to_4h.weight.data)
layer.mlp.dense_4h_to_h.weight.data = quantize_weight(
layer.mlp.dense_4h_to_h.weight.data)
layer.attention.query_key_value.weight.data = quantize_weight(layer.attention.query_key_value.weight.data)
layer.attention.dense.weight.data = quantize_weight(layer.attention.dense.weight.data)
layer.mlp.dense_h_to_4h.weight.data = quantize_weight(layer.mlp.dense_h_to_4h.weight.data)
layer.mlp.dense_4h_to_h.weight.data = quantize_weight(layer.mlp.dense_4h_to_h.weight.data)
def bert_layer_quantize(layer):
layer.attention.self.query.weight.data = quantize_weight(
layer.attention.self.query.weight.data)
layer.attention.self.key.weight.data = quantize_weight(
layer.attention.self.key.weight.data)
layer.attention.self.value.weight.data = quantize_weight(
layer.attention.self.value.weight.data)
layer.attention.output.dense.weight.data = quantize_weight(
layer.attention.output.dense.weight.data)
layer.attention.self.query.weight.data = quantize_weight(layer.attention.self.query.weight.data)
layer.attention.self.key.weight.data = quantize_weight(layer.attention.self.key.weight.data)
layer.attention.self.value.weight.data = quantize_weight(layer.attention.self.value.weight.data)
layer.attention.output.dense.weight.data = quantize_weight(layer.attention.output.dense.weight.data)
if preln:
layer.intermediate.dense_act.weight.data = quantize_weight(
layer.intermediate.dense_act.weight.data)
layer.intermediate.dense_act.weight.data = quantize_weight(layer.intermediate.dense_act.weight.data)
else:
layer.intermediate.dense.weight.data = quantize_weight(
layer.intermediate.dense.weight.data)
layer.intermediate.dense.weight.data = quantize_weight(layer.intermediate.dense.weight.data)
layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)
def quantize_fn(child):
......@@ -58,9 +52,7 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal
return child
return quantize_module(model=model,
orig_class=orig_layer_impl,
quantize_fn=quantize_fn)
return quantize_module(model=model, orig_class=orig_layer_impl, quantize_fn=quantize_fn)
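# Minimal, hedged sketch (not part of this diff) of the replace-by-class walk that
# quantize_module performs: recursively visit children and apply quantize_fn to each
# instance of orig_class. The helper names and the toy model are illustrative, not
# the DeepSpeed API.
import torch

def _apply_by_class(model, orig_class, fn):
    for name, child in model.named_children():
        if isinstance(child, orig_class):
            setattr(model, name, fn(child))
        else:
            _apply_by_class(child, orig_class, fn)
    return model

def _cast_weight_int8(child):
    # mirrors quantize_weight above: a plain dtype cast, no scales are kept
    child.weight.data = child.weight.data.to(torch.int8)
    return child

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
_apply_by_class(toy, torch.nn.Linear, _cast_weight_int8)
print(toy[0].weight.dtype)  # torch.int8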
def quantize_module(model, orig_class, quantize_fn):
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from abc import ABC, abstractmethod
from deepspeed.utils.types import ActivationFuncType
import torch
......@@ -70,7 +72,7 @@ class TransformerPolicy(DSPolicy):
self.split_qkv = split_qkv
@abstractmethod
def attention(self):
def attention(self, enable_training=False):
"""
Returns attention qkv and dense parameters
weight: (3*hidden, hidden) and (hidden, hidden)
......@@ -78,6 +80,13 @@ class TransformerPolicy(DSPolicy):
"""
raise NotImplementedError
@abstractmethod
def get_q_k_v(self):
"""
return all q,k,v parameters without merging them together
"""
raise NotImplementedError
@abstractmethod
def get_hidden_heads(self):
"""
......@@ -103,6 +112,14 @@ class TransformerPolicy(DSPolicy):
"""
raise NotImplementedError
@abstractmethod
def get_lora_params(self):
"""
Returns lora parameters used in transformer layer
"""
raise NotImplementedError
# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
......@@ -124,15 +141,10 @@ def _transpose(x, heads=1, mp_replace=None):
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=-1)
if len(q.shape) > 2:
new_shape = (q.shape[0], ) + (-1, )
return torch.cat((q.reshape(new_shape),
k.reshape(new_shape),
v.reshape(new_shape)),
return torch.cat((q.reshape(new_shape), k.reshape(new_shape), v.reshape(new_shape)),
dim=outer_dim).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
......@@ -156,19 +168,14 @@ def maybe_copy(module,
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst,
heads=heads,
mp_replace=mp_replace).contiguous())
dst = torch.nn.parameter.Parameter(_transpose(dst, heads=heads, mp_replace=mp_replace).contiguous())
else:
if split_qkv:
dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())), int8=weight_quantizer.q_int8)
else:
if qkv and megatron_v2:
tmp = _transpose(transpose(tmp),
heads=heads,
mp_replace=mp_replace).contiguous()
tmp = _transpose(transpose(tmp), heads=heads, mp_replace=mp_replace).contiguous()
if weight_quantizer.q_int8:
tmp = transpose(tmp)
dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
......@@ -177,13 +184,7 @@ def maybe_copy(module,
# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module,
sd,
weight_quantizer,
mp_replace,
dst_name,
src_names,
split_qkv=False):
def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names, split_qkv=False):
if src_names[0] in sd:
q = sd[src_names[0]]
k = sd[src_names[1]]
......@@ -203,3 +204,19 @@ def maybe_copy_qkv(module,
dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.to(get_accelerator().device_name()) if weight_quantizer.q_int8 else \
transpose(qkv_data)), int8=weight_quantizer.q_int8)
setattr(module, dst_name, dst)
def pack_lora_weights(p):
return [
p.lora_right_weight, \
p.lora_left_weight, \
p.lora_scaling
]
def maybe_get_lora(p):
if hasattr(p, 'lora_right_weight'):
lora_param = pack_lora_weights(p)
else:
lora_param = []
return lora_param
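# Hedged usage sketch (not part of this diff), using maybe_get_lora defined above: a
# parameter wrapped by a LoRA layer is expected to carry lora_right_weight,
# lora_left_weight and lora_scaling attributes, which get packed into a list; plain
# parameters yield []. The shapes and scaling value below are placeholders.
import torch

p = torch.nn.Parameter(torch.zeros(8, 8))
print(maybe_get_lora(p))  # [] -- no LoRA attributes attached

p.lora_right_weight = torch.zeros(8, 2)  # low-rank factors (rank 2 here)
p.lora_left_weight = torch.zeros(2, 8)
p.lora_scaling = 0.5
print(maybe_get_lora(p))  # [lora_right_weight, lora_left_weight, 0.5]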
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
......@@ -23,6 +26,7 @@ from .utils import policy_to_ds_container
class ReplaceWithTensorSlicing:
def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
if mp_group is not None:
self.gpu_index = dist.get_rank(group=mp_group)
......@@ -38,7 +42,7 @@ class ReplaceWithTensorSlicing:
for merging your checkpoints before replacing the transformer layer with\
inference-kernels'
def qkv_copy(self, dst, src, int8=False):
def qkv_copy(self, dst, src, int8=False, allocat_tensor=False):
if src is None:
return src
src_shape = src.shape
......@@ -47,6 +51,9 @@ class ReplaceWithTensorSlicing:
outer_dim = 0 if int8 else -1
inner_dim = -1 if int8 else 0
if allocat_tensor:
dst = torch.empty_like(dst)
src_split = torch.split(src.data, src.shape[outer_dim] // 3, dim=outer_dim)
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[outer_dim] == dst_shape[self.out_dim]:
......@@ -55,82 +62,59 @@ class ReplaceWithTensorSlicing:
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
if self.out_dim == 1:
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
qkv_size = dst_shape[self.out_dim] // 3
qkv_split = [
torch.split(src_s,
qkv_size,
dim=outer_dim) for src_s in src_split
]
weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=outer_dim) for i in range(len(qkv_split[0]))
]
dst = dst.reshape(-1).data.copy_(
weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
weight_split[self.gpu_index].shape)
else:
dst.data.copy_(src_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
qkv_size = dst_shape[self.out_dim] // 3
qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
]
dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
weight_split[self.gpu_index].shape)
else:
if src_shape[0] == dst_shape[0]:
return torch.nn.parameter.Parameter(src)
if self.out_dim == 1:
qkv_size = dst_shape[0] // 3
qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
bias_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=0) for i in range(len(qkv_split[0]))
]
dst.data.copy_(bias_split[self.gpu_index].contiguous())
else:
dst.data.copy_(src_split[self.gpu_index].contiguous())
qkv_size = dst_shape[0] // 3
qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
dst.data.copy_(bias_split[self.gpu_index].contiguous())
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
def copy(self, dst, src, int8=False):
def copy(self, dst, src, int8=False, allocat_tensor=False):
if src is None:
return src
assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors
if allocat_tensor:
dst = torch.empty_like(dst)
outer_dim = 0 if int8 else 1
inner_dim = 1 if int8 else 0
src_shape = src.shape
dst_shape = dst.shape
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[inner_dim] == dst_shape[
self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
else:
if src_shape[inner_dim] != dst_shape[self.in_dim]:
self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
weight_split = torch.split(
src,
dst_shape[self.in_dim],
dim=inner_dim)[self.gpu_index].contiguous()
dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
src[self.gpu_index * dst_shape[self.in_dim]: (self.gpu_index + 1) * dst_shape[self.in_dim], :])
else:
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
weight_split = torch.split(
src.data,
dst_shape[self.out_dim],
dim=outer_dim)[self.gpu_index].contiguous()
dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(
weight_split.shape)
dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
src[self.gpu_index * dst_shape[self.out_dim]: (self.gpu_index + 1) * dst_shape[self.out_dim], :])
else:
if src_shape[0] == dst_shape[0]:
dst.data.copy_(src)
dst = src
else:
bias_split = torch.split(src.data,
dst_shape[-1])[self.gpu_index].contiguous()
dst.data.copy_(bias_split)
dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
......@@ -150,6 +134,7 @@ def get_transformer_name(replaced_module):
class GroupQuantizer:
def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
self.group_size = group_size
self.num_bits = num_bits
......@@ -163,8 +148,7 @@ class GroupQuantizer:
inputs.scale = torch.empty(1)
return inputs
q_range = 2**self.num_bits
num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[
0] // self.group_size
num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
inputs = inputs.to(get_accelerator().current_device_name())
input_flat = inputs.reshape(num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
......@@ -174,31 +158,14 @@ class GroupQuantizer:
inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
out = torch.nn.Parameter(inputs_q, requires_grad=False)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [
inputs_split[i].reshape(num_groups,
-1).contiguous() for i in range(2)
]
input_min = [
torch.min(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
input_max = [
torch.max(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
scale1 = [
(torch.max(input_min[i].abs(),
input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
for i in range(2)
]
out.scale = torch.cat([scale.squeeze().unsqueeze(0),
scale1[0],
scale1[1]],
dim=0).reshape(num_groups,
-1).contiguous()
input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
for i in range(2)]
out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
-1).contiguous()
return out
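# Hedged sketch (not part of this diff) of the per-group symmetric scheme GroupQuantizer
# implements: each group gets one scale derived from its absolute min/max,
# scale = max(|min|, |max|) * 2 / 2**num_bits, and values are rounded into int8. This is
# a simplified illustration of the scale math, not the exact inference kernel path.
import torch

def group_quantize(x, num_groups, num_bits=8):
    q_range = 2 ** num_bits
    flat = x.reshape(num_groups, -1).contiguous()
    gmin = flat.min(dim=1, keepdim=True)[0].float()
    gmax = flat.max(dim=1, keepdim=True)[0].float()
    scale = torch.max(gmin.abs(), gmax.abs()) * 2.0 / q_range          # one scale per group
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)  # symmetric rounding
    return q.reshape(x.shape).to(torch.int8), scale

w = torch.randn(8, 16)
q, scales = group_quantize(w, num_groups=8)
print(q.dtype, scales.shape)  # torch.int8 torch.Size([8, 1])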
......@@ -211,6 +178,7 @@ def _module_match(module):
def generic_injection(module, fp16=False, enable_cuda_graph=True):
def replace_attn(child, policy):
policy_attn = policy.attention(child)
if policy_attn is None:
......@@ -246,8 +214,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
attn_module.attn_qkvb = None
attn_module.attn_ow.data = transpose(attn_ow.data)
attn_module.attn_ob.data.copy_(
attn_ob.data.to(get_accelerator().current_device_name()))
attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
return attn_module
def replace_attn_block(child, policy):
......@@ -262,7 +229,10 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
try:
import diffusers
cross_attention = diffusers.models.attention.CrossAttention
if hasattr(diffusers.models.attention, 'CrossAttention'):
cross_attention = diffusers.models.attention.CrossAttention
else:
cross_attention = diffusers.models.attention_processor.Attention
attention_block = diffusers.models.attention.BasicTransformerBlock
new_policies = {
cross_attention: replace_attn,
......@@ -278,8 +248,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
# triangular_masking=True,
# max_out_tokens=8192)
from ..model_implementations.transformers.clip_encoder import DSClipEncoder
cg_encoder = DSClipEncoder(module.text_encoder,
enable_cuda_graph=enable_cuda_graph)
cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
setattr(module, 'text_encoder', cg_encoder)
for name in module.__dict__.keys():
sub_module = getattr(module, name)
......@@ -291,13 +260,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
for name, child in module.named_children():
_replace_module(child, policy)
if child.__class__ in new_policies:
replaced_module = new_policies[child.__class__](child,
policy)
replaced_module = new_policies[child.__class__](child, policy)
setattr(module, name, replaced_module)
_replace_module(sub_module, policy)
new_module = policy.apply(sub_module,
enable_cuda_graph=enable_cuda_graph)
new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
print(f"**** found and replaced {name} w. {type(new_module)}")
setattr(module, name, new_module)
......@@ -305,11 +272,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
container_g = None
def replace_transformer_layer(orig_layer_impl,
model,
checkpoint_dict,
config,
model_config):
def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
......@@ -334,15 +297,10 @@ def replace_transformer_layer(orig_layer_impl,
seed = -1
local_rank = -1
mp_replace = ReplaceWithTensorSlicing(
mp_group=config.tensor_parallel.tp_group,
mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
def replace_with_policy(child,
policy_cls,
triangular_masking,
inference=False,
layer_id=0):
def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
policy = policy_cls(child, inference=inference)
if not policy.cuda_graph_supported:
# policy says cuda graph is not supported raise an error if set
......@@ -364,8 +322,7 @@ def replace_transformer_layer(orig_layer_impl,
_container.set_moe(moe)
# 2. Set the tensor parallelism config
_container.set_tensor_parallel_config(config.tensor_parallel.tp_size,
config.tensor_parallel.tp_group)
_container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
# 3. Initialize tensors
_container.initialize_tensors()
......@@ -411,25 +368,21 @@ def replace_transformer_layer(orig_layer_impl,
if name in all_reduce_linears:
new_weight = torch.empty((
weight_shape[1] if conv_linear_layer else weight_shape[0],
(weight_shape[0] if conv_linear_layer else weight_shape[1]) //
mp_size,
(weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size,
),
device=child.weight.device,
dtype=child.weight.dtype)
if conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = mp_replace.copy(new_weight, child.weight.data)
new_bias = torch.empty((weight_shape[0]),
device=child.weight.device,
dtype=child.weight.dtype)
new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype)
if child.bias is not None:
new_bias.data.copy_(child.bias.data)
return LinearAllreduce(data, child.bias if child.bias is None else \
torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
else:
new_weight = torch.empty((
(weight_shape[1] if conv_linear_layer else weight_shape[0]) //
mp_size,
(weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
),
device=child.weight.device,
......@@ -441,43 +394,54 @@ def replace_transformer_layer(orig_layer_impl,
new_bias = torch.empty((weight_shape[0] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(get_accelerator().current_device_name())
return LinearLayer(weight=data.to(
get_accelerator().current_device_name()),
bias=bias_data)
bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
get_accelerator().current_device_name())
return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)
def _slice_embedding(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0],
child.weight.shape[1] // mp_size),
new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
data = mp_replace.copy(new_weight,
child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
child.weight.data)
new_embedding = nn.Embedding(child.weight.shape[0],
child.weight.shape[1] // mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
return new_embedding
def update_mp_params(child):
if hasattr(child, 'n_heads'):
assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
child.n_heads, mp_size)
child.n_heads = child.n_heads // mp_size
if hasattr(child, 'inner_dim'):
assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
child.inner_dim, mp_size)
child.inner_dim = child.inner_dim // mp_size
if hasattr(child, 'num_heads'):
assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
child.num_heads, mp_size)
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
child.num_attention_heads, mp_size)
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'num_attn_heads'):
assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
child.num_attn_heads, mp_size)
child.num_attn_heads = child.num_attn_heads // mp_size
if hasattr(child, 'all_head_size'):
assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
child.all_head_size, mp_size)
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(
child.embed_dim, mp_size)
child.embed_dim = child.embed_dim // mp_size
if hasattr(child, 'hidden_size'):
assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
child.hidden_size, mp_size)
child.hidden_size = child.hidden_size // mp_size
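# Worked example (illustrative, not part of this diff): the divisions above shrink the
# per-rank sizes, e.g. with mp_size = 4 a layer with 32 attention heads and
# hidden_size 4096 ends up with 8 heads and 1024 hidden units per tensor-parallel rank.
_mp_size, _num_attention_heads, _hidden_size = 4, 32, 4096
assert _num_attention_heads % _mp_size == 0 and _hidden_size % _mp_size == 0
print(_num_attention_heads // _mp_size, _hidden_size // _mp_size)  # 8 1024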
conv_linear_layer = False
......@@ -499,12 +463,8 @@ def replace_transformer_layer(orig_layer_impl,
def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
if child.__class__ in linear_policies:
setattr(
r_module,
name,
linear_policies[child.__class__](child,
prev_name + '.' + name,
conv_linear_layer))
setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
conv_linear_layer))
else:
update_mp_params(child)
_replace_module(child, name)
......@@ -551,15 +511,10 @@ def replace_transformer_layer(orig_layer_impl,
base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)
if ckpt_type == 'pp' and type(checkpoint) is list:
pbar = tqdm.tqdm(total=len(checkpoint),
desc=f"Loading {len(checkpoint)} checkpoint shards")
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
sd = [
torch.load(os.path.join(base_dir1,
checkpoint[i]),
map_location='cpu')
]
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
......@@ -574,22 +529,15 @@ def replace_transformer_layer(orig_layer_impl,
tp_split_size = (world_size / ckpt_mp_size)
sd_offset = int(rank / tp_split_size)
sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
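# Worked example (illustrative, not part of this diff) of the shard bookkeeping above.
# Case 1: world_size = 4, ckpt_mp_size = 2 -> tp_split_size = 2.0; rank 3 gets
#   sd_offset = int(3 / 2.0) = 1 and sd_count = int((3 + 2.0) / 2.0) - 1 = 1, so each
#   iteration loads the single shard at ckpt_index = i * 2 + 1.
# Case 2: world_size = 2, ckpt_mp_size = 4 -> tp_split_size = 0.5; rank 1 gets
#   sd_offset = 2 and sd_count = int((1 + 1) / 0.5) - 2 = 2, i.e. two checkpoint shards
#   are merged into each tensor-parallel rank.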
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
for i in range(num_checkpoints):
pbar.update(1)
ckpt_index = i * ckpt_mp_size + sd_offset
ckpt_files = [
os.path.join(base_dir1,
ckpt_list[ckpt_index +
j]) if base_dir1 else ckpt_list[ckpt_index +
j]
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
sds = [
torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files
]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
......@@ -602,15 +550,13 @@ def replace_transformer_layer(orig_layer_impl,
gc.collect()
if "non_tp" in checkpoint:
pbar = tqdm.tqdm(
total=len(checkpoint["non_tp"]),
desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
for i in range(len(checkpoint["non_tp"])):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i]
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sds,
......@@ -649,37 +595,22 @@ def replace_transformer_layer(orig_layer_impl,
if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints")
torch.save(
OrderedDict({
k: v
for k,
v in dict(replaced_module.state_dict()).items()
if transformer_name not in k
}),
f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
OrderedDict({k: v
for k, v in dict(replaced_module.state_dict()).items()
if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
ckpt_config = json.dumps({
'type':
ckpt_name,
'base_dir':
f'{config.save_mp_checkpoint_path}',
'type': ckpt_name,
'base_dir': f'{config.save_mp_checkpoint_path}',
'checkpoints': {
"non_tp":
ckpt_files,
"tp": [
f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions)
for r in range(world_size)
]
"non_tp": ckpt_files,
"tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
},
'version':
1.0,
'parallelization':
'tp',
'tp_size':
world_size,
'dtype':
'int8' if quantize else ('float16' if fp16 else 'float32')
'version': 1.0,
'parallelization': 'tp',
'tp_size': world_size,
'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
})
with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json",
"w") as cfg:
with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
cfg.write(ckpt_config)
rep_sd = replaced_module.state_dict()
......@@ -691,13 +622,9 @@ def replace_transformer_layer(orig_layer_impl,
for m in range(num_partitions):
torch.save(
OrderedDict({
k: [rep_sd[k],
rep_sd[k].scale] if hasattr(rep_sd[k],
'scale') else rep_sd[k]
for k in keys[m * partition_size:(m + 1) * partition_size]
if transformer_name in k
}),
f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')
k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
}), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')
return replaced_module
......@@ -712,6 +639,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
Returns:
Updated nn.module with original bert-style transformer layers
"""
def replace_fn(child, _replace_policy, layer_id):
#from turing.nvidia_modelingpreln import BertLayer
orig_module = orig_layer_impl(config)
......@@ -813,9 +741,7 @@ def _replace_module(model, policies, layer_id=0):
"""
for name, child in model.named_children():
if child.__class__ in policies:
replaced_module = policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id)
replaced_module = policies[child.__class__][0](child, policies[child.__class__][-1], layer_id)
setattr(model, name, replaced_module)
if isinstance(model, PipelineModule):
assert hasattr(model, 'forward_funcs'),\
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .containers import HFGPT2LayerPolicy
from .containers import HFBertLayerPolicy
from .containers import BLOOMLayerPolicy
......@@ -16,16 +18,8 @@ from .containers import VAEPolicy
# transformer-based policies
replace_policies = [
HFBertLayerPolicy,
HFGPTNEOLayerPolicy,
GPTNEOXLayerPolicy,
HFGPTJLayerPolicy,
MegatronLayerPolicy,
HFGPT2LayerPolicy,
BLOOMLayerPolicy,
HFOPTLayerPolicy,
HFCLIPLayerPolicy,
HFDistilBertLayerPolicy
HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy,
HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy
]
# non-transformer-based policies
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.utils import log_dist
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import copy
class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1, expert_group_name=None):
super(Experts, self).__init__()
self.deepspeed_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)])
self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
self.num_local_experts = num_local_experts
# TODO: revisit allreduce for moe.gate...
......
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
......@@ -31,6 +32,7 @@ class MoE(torch.nn.Module):
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
"""
def __init__(self,
hidden_size,
expert,
......@@ -65,15 +67,8 @@ class MoE(torch.nn.Module):
'Unsupported noisy_gate_policy: ' + noisy_gate_policy
experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size,
num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts),
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
experts,
self.expert_group_name,
self.ep_size,
......@@ -90,20 +85,16 @@ class MoE(torch.nn.Module):
def _create_process_groups(self):
# Create process group for a layer if needed
if self.expert_group_name not in groups._get_expert_parallel_group_dict():
print(
f"No existing process group found, creating a new group named: {self.expert_group_name}"
)
print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
# Condition 1 - no groups.mpu means no tensor parallelism
# Condition 2 - disabling expert tensor parallelism on purpose
groups._create_expert_and_data_parallel(self.ep_size)
else:
# expert tensor parallelism is enabled
groups._create_expert_data_and_model_parallel(self.ep_size,
mpu=groups.mpu)
groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu)
# Set the group handle for the MOELayer (deepspeed_moe) object
self.deepspeed_moe._set_ep_group(
groups._get_expert_parallel_group(self.expert_group_name))
self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))
def forward(self, hidden_states, used_token=None):
""" MoE forward
......
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# The file has been adapted from the following Megatron-LM file:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
......@@ -32,14 +33,9 @@ def _gather_tokens(input_, dim=0):
# Size and dimension.
rank = mpu.get_tensor_model_parallel_rank()
tensor_list = [
torch.empty_like(input_)
for _ in range(mpu.get_tensor_model_parallel_world_size())
]
tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
tensor_list[rank] = input_
deepspeed.comm.all_gather(tensor_list,
input_,
group=mpu.get_tensor_model_parallel_group())
deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
......@@ -53,7 +49,8 @@ def _drop_tokens(input_, dim=0):
total_chunks = mpu.get_tensor_model_parallel_world_size()
this_chunk = mpu.get_tensor_model_parallel_rank()
assert input_.shape[dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
assert input_.shape[
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
chunk_size = input_.shape[dim] // total_chunks
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
......@@ -61,6 +58,7 @@ def _drop_tokens(input_, dim=0):
class _GatherTokens(torch.autograd.Function):
"""All gather tokens among the tensor parallel ranks"""
@staticmethod
def symbolic(graph, input_, dim):
return _gather_tokens(input_, dim)
......
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# The file has been adapted from two fairscale files:
# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
# We retain the following license from the original files:
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
The file has been adapted from two fairscale files:
(1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
(2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
We retain the following license from the original files:
"""
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
......@@ -60,11 +63,9 @@ def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
return x
uniform = uniform_map.get(device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - epsilon,
device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform_map[device] = uniform
return x * uniform(x.shape)
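# Hedged sketch (not part of this diff): the jitter rescales each element by an
# independent Uniform(1 - eps, 1 + eps) factor, so values move by at most eps relative
# to their magnitude while keeping the gate stochastic during training.
import torch

eps = 1e-2
x = torch.randn(4, 8)
factor = torch.empty_like(x).uniform_(1.0 - eps, 1.0 + eps)
jittered = x * factor
print(((jittered - x).abs() <= x.abs() * eps + 1e-7).all())  # tensor(True)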
......@@ -87,6 +88,7 @@ from deepspeed import comm as dist
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
......@@ -181,25 +183,18 @@ def top1gating(logits: Tensor,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False) -> Tuple[Tensor,
Tensor,
Tensor,
Tensor]:
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates,
torch.tensor(capacity_factor),
torch.tensor(min_capacity))
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(
logits_w_noise if noisy_gate_policy == 'RSample' else gates,
dim=1)
indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
......@@ -225,18 +220,16 @@ def top1gating(logits: Tensor,
if use_rts:
uniform = exp_selection_uniform_map.get(logits.device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0,
device=logits.device),
high=torch.tensor(1.0,
device=logits.device)).rsample
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
high=torch.tensor(1.0, device=logits.device)).rsample
exp_selection_uniform_map[logits.device] = uniform
mask1_rand = mask1 * uniform(mask1.shape)
else:
mask1_rand = mask1
assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
assert logits.shape[
0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
top_idx = _top_idx(mask1_rand, capacity)
......@@ -258,7 +251,13 @@ def top1gating(logits: Tensor,
if use_tutel:
gates1_s = (gates * mask1).sum(dim=1)
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts
return l_aux, capacity, num_experts, [
indices1_s,
], [
locations1_s,
], [
gates1_s,
], exp_counts
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
......@@ -275,19 +274,12 @@ def top1gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts
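# Hedged worked example (not part of this diff) of the capacity term used above,
# assuming the usual ceil(tokens / num_experts * capacity_factor) with a min_capacity
# floor: 4096 tokens routed over 64 experts at capacity_factor = 1.25 gives
# ceil(4096 / 64 * 1.25) = 80 slots per expert, and a min_capacity of 4 does not bind.
# Tokens beyond an expert's 80 slots are dropped unless drop_tokens is disabled.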
def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int) -> Tuple[Tensor,
Tensor,
Tensor,
Tensor]:
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates,
torch.tensor(capacity_factor * 2),
torch.tensor(min_capacity))
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
......@@ -393,13 +385,10 @@ class TopKGate(Module):
self.drop_tokens = drop_tokens
self.use_rts = use_rts
def forward(
self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False) -> Tuple[Tensor,
Tensor,
Tensor]: # type: ignore
def forward(self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
if self.wall_clock_breakdown:
self.timers('TopKGate').start()
......@@ -413,21 +402,13 @@ class TopKGate(Module):
logits = self.wg(input_fp32)
if self.k == 1:
gate_output = top1gating(
logits,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity,
used_token,
self.noisy_gate_policy if self.training else None,
self.drop_tokens,
self.use_rts,
use_tutel)
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)
else:
gate_output = top2gating(
logits,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
if self.wall_clock_breakdown:
self.timers('TopKGate').stop()
......@@ -453,6 +434,7 @@ class MOELayer(Base):
expert (torch.nn.Module):
expert network
"""
def __init__(self,
gate: Module,
experts: Module,
......@@ -481,9 +463,8 @@ class MOELayer(Base):
logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning(
"To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")
logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")
def _set_ep_group(self, ep_group):
self.ep_group = ep_group
......@@ -506,18 +487,12 @@ class MOELayer(Base):
S, M = reshaped_input.size(0), reshaped_input.size(1)
if not hasattr(self, '_tutel_dispatcher'):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
E,
C,
M,
dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else:
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
dispatched_input = einsum("sec,sm->ecm",
dispatch_mask.type_as(input[0]),
reshaped_input)
dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
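# Hedged shape sketch (not part of this diff) of the dispatch einsum above:
# dispatch_mask is (s, e, c) = (tokens, experts, capacity) and reshaped_input is
# (s, m) = (tokens, model_dim), so "sec,sm->ecm" scatters each token's features into
# its assigned expert/capacity slot. Sizes below are arbitrary placeholders.
import torch

s, e, c, m = 6, 2, 3, 4
dispatch_mask = torch.zeros(s, e, c)
dispatch_mask[0, 1, 0] = 1.0                         # token 0 -> expert 1, slot 0
tokens = torch.randn(s, m)
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
print(dispatched.shape)                              # torch.Size([2, 3, 4])
print(torch.allclose(dispatched[1, 0], tokens[0]))   # True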
if self.wall_clock_breakdown:
self.timers('falltoall').start()
......@@ -538,10 +513,7 @@ class MOELayer(Base):
self.time_falltoall = self.timers('falltoall').elapsed(reset=False)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size,
self.num_local_experts,
-1,
d_model)
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
expert_output = self.experts(dispatched_input)
......@@ -555,9 +527,7 @@ class MOELayer(Base):
self.time_salltoall = self.timers('salltoall').elapsed(reset=False)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts,
-1,
d_model)
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
if groups._get_expert_model_parallel_world_size() == 1:
# the dropped duplicate tokens need to be gathered on each
......@@ -568,9 +538,7 @@ class MOELayer(Base):
if self.use_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
else:
combined_output = einsum("sec,ecm->sm",
combine_weights.type_as(input[0]),
expert_output)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
a = combined_output.reshape(input[0].shape)
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List, Tuple, Dict
import torch
......@@ -24,8 +27,7 @@ def is_moe_param(param: torch.Tensor) -> bool:
def split_params_into_shared_and_expert_params(
params: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter,
torch.nn.Parameter]:
params: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
shared_params, expert_params = [], []
for p in params:
if is_moe_param(p):
......@@ -36,8 +38,7 @@ def split_params_into_shared_and_expert_params(
def split_params_grads_into_shared_and_expert_params(
group: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter,
torch.nn.Parameter]:
group: List[torch.nn.Parameter]) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
"""Split grad of parameters into grads of non-expert params
and grads of expert params. This is useful while computing
grad-norms for clipping and overflow detection
......@@ -62,8 +63,7 @@ def split_params_grads_into_shared_and_expert_params(
def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
max_group_size=178956971
) -> Tuple[Dict]:
max_group_size=178956971) -> Tuple[Dict]:
"""Split parameters into different MoE groups for optimizer
Args:
......@@ -101,8 +101,7 @@ def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dic
if ori_key == 'params':
group_moe[param_group['name']][key][ori_key] = []
else:
group_moe[
param_group['name']][key][ori_key] = param_group[ori_key]
group_moe[param_group['name']][key][ori_key] = param_group[ori_key]
# Assign param
for param_group in param_groups:
new_params = []
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from pydantic import root_validator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
def get_monitor_config(param_dict):
monitor_dict = {
key: param_dict.get(key,
{})
for key in ("tensorboard",
"wandb",
"csv_monitor")
}
monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")}
return DeepSpeedMonitorConfig(**monitor_dict)
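# Hedged example (not part of this diff) of the dict shape the comprehension above
# consumes: only the three monitor keys are picked out of a larger DeepSpeed config,
# and missing ones default to {}. The surrounding keys/values are placeholders.
sample_ds_config = {"tensorboard": {"enabled": True}, "optimizer": {"type": "Adam"}}
picked = {key: sample_ds_config.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")}
print(picked)  # {'tensorboard': {'enabled': True}, 'wandb': {}, 'csv_monitor': {}}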
......@@ -78,10 +71,10 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel):
csv_monitor: CSVConfig = {}
""" Local CSV output of monitoring data. """
@root_validator
def check_enabled(cls, values):
values["enabled"] = False
if (values.get("tensorboard").enabled or values.get("wandb").enabled
or values.get("csv_monitor").enabled):
if (values.get("tensorboard").enabled or values.get("wandb").enabled or values.get("csv_monitor").enabled):
values["enabled"] = True
return values
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .monitor import Monitor
import os
......@@ -7,6 +10,7 @@ import deepspeed.comm as dist
class csvMonitor(Monitor):
def __init__(self, csv_config):
super().__init__(csv_config)
self.filenames = []
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Support different forms of monitoring such as wandb and tensorboard
"""
from abc import ABC, abstractmethod
......@@ -8,6 +11,7 @@ import deepspeed.comm as dist
class Monitor(ABC):
@abstractmethod
def __init__(self, monitor_config):
self.monitor_config = monitor_config
......@@ -23,6 +27,7 @@ from .csv_monitor import csvMonitor
class MonitorMaster(Monitor):
def __init__(self, monitor_config):
super().__init__(monitor_config)
self.tb_monitor = None
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .utils import check_tb_availability
from .monitor import Monitor
......@@ -8,6 +11,7 @@ import deepspeed.comm as dist
class TensorBoardMonitor(Monitor):
def __init__(self, tensorboard_config):
super().__init__(tensorboard_config)
check_tb_availability()
......@@ -20,9 +24,7 @@ class TensorBoardMonitor(Monitor):
if self.enabled and dist.get_rank() == 0:
self.get_summary_writer()
def get_summary_writer(self,
base=os.path.join(os.path.expanduser("~"),
"tensorboard")):
def get_summary_writer(self, base=os.path.join(os.path.expanduser("~"), "tensorboard")):
if self.enabled and dist.get_rank() == 0:
from torch.utils.tensorboard import SummaryWriter
if self.output_path is not None:
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
def check_tb_availability():
......
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .utils import check_wandb_availability
from .monitor import Monitor
......@@ -7,6 +10,7 @@ import deepspeed.comm as dist
class WandbMonitor(Monitor):
def __init__(self, wandb_config):
super().__init__(wandb_config)
check_wandb_availability()
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''