Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.
from abc import ABC

@@ -15,6 +18,7 @@ class BaseConvolutionContainer(ABC):
class BaseTransformerContainer(ABC):

    def __init__(self, policy, config, model_config, layer_id, child):
        self.policy = policy
        self.config = config

@@ -30,28 +34,22 @@ class BaseTransformerContainer(ABC):

        self.hidden_size = None
        self.num_attention_heads = None
        self.mp_size = self.config.tensor_parallel.tp_size
        self.pre_layer_norm = self.model_config.do_layer_norm_before if \
            hasattr(self.model_config, 'do_layer_norm_before') else self.policy.pre_attn_norm
        self.fp16 = False
        self.attn_linear_layer = self.policy.linear_layer
        self.mlp_linear_layer = self.policy.linear_layer
        self.return_tuple = self.config.return_tuple
        self.triangular_masking = True
        self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr(
            self.model_config, 'attention_layers') else False)
        self.window_size = getattr(self.model_config, "window_size", 1)
        self.mlp_act_func_type = self.policy.mlp_act_func_type
        self.training_mp_size = self.config.training_mp_size
        self.bigscience_bloom = False
        self.max_out_tokens = self.config.max_out_tokens
        self.min_out_tokens = self.config.min_out_tokens
        self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False)
        self.use_mup = self.policy.use_mup
        self.return_single_tuple = False
        self.rotary_dim = self.model_config.rotary_dim if hasattr(self.model_config, 'rotary_dim') \

@@ -75,6 +73,8 @@ class BaseTransformerContainer(ABC):

        self.input_nw = None
        self.input_nb = None
        self.mp_group = None

    def create_ds_model_config(self):
        self.set_hidden_heads(*self.policy.get_hidden_heads())
        assert self.num_attention_heads % self.mp_size == 0,\

@@ -84,11 +84,11 @@ class BaseTransformerContainer(ABC):

        self.ds_model_config = DeepSpeedInferenceConfig(
            hidden_size=self.hidden_size,
            heads=self.num_attention_heads,
            layer_norm_eps=self.layernorm_epsilon,
            fp16=self.fp16,
            pre_layer_norm=self.pre_layer_norm,
            mp_size=self.mp_size,
            q_int8=self.quantize if hasattr(self, 'quantize') else False,
            return_tuple=self.return_tuple,
            triangular_masking=self.triangular_masking,
            local_attention=self.local_attention,

@@ -99,18 +99,24 @@ class BaseTransformerContainer(ABC):

            training_mp_size=self.training_mp_size,
            bigscience_bloom=self.bigscience_bloom,
            max_out_tokens=self.max_out_tokens,
            min_out_tokens=self.min_out_tokens,
            scale_attn_by_inverse_layer_idx=self.scale_attn_by_inverse_layer_idx,
            use_mup=self.use_mup,
            return_single_tuple=self.return_single_tuple,
            set_empty_params=self.config.set_empty_params,
            transposed_mode=self.config.transposed_mode)

        return self.ds_model_config

    def initialize_tensors(self, enable_training=False):
        # Set the tensors from policy (user module) to container (DS module)
        self.set_attention(*self.policy.attention(enable_training=enable_training))
        self.set_mlp(*self.policy.mlp())
        self.set_layernorm(*self.policy.layernorm())
        self.set_lora_params(self.policy.get_lora_params())
        self.q_k_v = self.policy.get_q_k_v()
        if self.q_k_v is not None:
            self.set_q_k_v(*self.q_k_v)

    def convert_to_required_dtype(self, dtype):
        # Note: converting tensors to fp16 requires that we do it in-place using self.__dict__ and not make a list/dict copy

@@ -138,9 +144,10 @@ class BaseTransformerContainer(ABC):

        self.quantize = quantize
        self.quantizer = quantizer

    def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = epsilon

    def set_attention(self, qkvw, qkvb, dense_w, dense_b):
        self.qkvw = qkvw

@@ -148,6 +155,17 @@ class BaseTransformerContainer(ABC):

        self.dense_w = dense_w
        self.dense_b = dense_b

    def set_lora_params(self, lora_params):
        self.lora_params = lora_params

    def set_q_k_v(self, qw, qb, kw, kb, vw, vb):
        self.qw = qw
        self.qb = qb
        self.kw = kw
        self.kb = kb
        self.vw = vw
        self.vb = vb

    def set_mlp(self, _h4h_w, _h4h_b, _4hh_w, _4hh_b):
        self._h4h_w = _h4h_w
        self._h4h_b = _h4h_b

@@ -168,63 +186,184 @@ class BaseTransformerContainer(ABC):

        self.mlp_quantization()

    def attention_quantization(self):
        self.module.attention.attn_qkvw = self.quantizer.quantize(self.module.attention.attn_qkvw)
        self.module.attention.attn_ow = self.quantizer.quantize(self.module.attention.attn_ow)

    def mlp_quantization(self):
        self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
        self.module.mlp.output_w = self.quantizer.quantize(self.module.mlp.output_w)

    def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
        reversed_dim = False
        if mp_replace is None:
            from deepspeed.module_inject import ReplaceWithTensorSlicing
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=tp_size, out_dim=0, in_dim=1)
            reversed_dim = True
        # setup the new Attention module
        if self.module.attention.attn_qkvw is None:
            self.attention_q_k_v_mp(mp_replace, reversed_dim=reversed_dim)
        else:
            self.attention_qkv_mp(mp_replace, reversed_dim=reversed_dim)
        self.attention_o_mp(mp_replace, reversed_dim=reversed_dim)

        # setup the new MLP module
        self.mlp_inter_mp(mp_replace, reversed_dim=reversed_dim)
        self.mlp_output_mp(mp_replace, reversed_dim=reversed_dim)

        # Apply weight quantization
        #self.apply_weight_quantization()

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.attention.attn_qkvw = mp_replace.qkv_copy(
                self.module.attention.attn_qkvw[:self.qkvw.shape[0] // mp_replace.mp_size],
                self.qkvw,
                int8=reversed_dim)
            self.module.attention.attn_qkvb = mp_replace.qkv_copy(
                self.module.attention.attn_qkvb[:self.qkvw.shape[0] // mp_replace.mp_size],
                self.qkvb,
                int8=reversed_dim)
        else:
            self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw,
                                                                  self.qkvw,
                                                                  int8=reversed_dim)
            self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb,
                                                                  self.qkvb,
                                                                  int8=reversed_dim)

    def attention_q_k_v_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qw = mp_replace.copy(self.module.attention.attn_qw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.qw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_kw = mp_replace.copy(self.module.attention.attn_kw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.kw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_vw = mp_replace.copy(self.module.attention.attn_vw[:self.qw.shape[0] //
                                                                                      mp_replace.mp_size],
                                                        self.vw,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)
        self.module.attention.attn_qb = mp_replace.copy(
            self.module.attention.attn_qb[:self.qw.shape[0] // mp_replace.mp_size],
            self.qb,
            int8=reversed_dim,
            allocat_tensor=reversed_dim) if self.module.attention.attn_qb is not None else None
        self.module.attention.attn_kb = mp_replace.copy(
            self.module.attention.attn_kb[:self.qw.shape[0] // mp_replace.mp_size],
            self.kb,
            int8=reversed_dim,
            allocat_tensor=reversed_dim) if self.module.attention.attn_kb is not None else None
        self.module.attention.attn_vb = mp_replace.copy(
            self.module.attention.attn_vb[:self.qw.shape[0] // mp_replace.mp_size],
            self.vb,
            int8=reversed_dim,
            allocat_tensor=reversed_dim) if self.module.attention.attn_vb is not None else None

    def attention_o_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow[:, :self.dense_w.shape[1] //
                                                                                          mp_replace.mp_size],
                                                            self.dense_w,
                                                            int8=reversed_dim,
                                                            allocat_tensor=reversed_dim)
        else:
            self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
                                                            self.dense_w,
                                                            int8=reversed_dim)
        self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
                                                        self.dense_b,
                                                        int8=reversed_dim,
                                                        allocat_tensor=reversed_dim)

    def mlp_inter_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w[:self._h4h_w.shape[0] //
                                                                              mp_replace.mp_size],
                                                      self._h4h_w,
                                                      int8=reversed_dim,
                                                      allocat_tensor=reversed_dim)
            self.module.mlp.inter_b = mp_replace.copy(
                self.module.mlp.inter_b[:self._h4h_w.shape[0] // mp_replace.mp_size],
                self._h4h_b,
                int8=reversed_dim,
                allocat_tensor=reversed_dim) if self.module.mlp.inter_b is not None else None
        else:
            self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w, int8=reversed_dim)
            self.module.mlp.inter_b = mp_replace.copy(self.module.mlp.inter_b, self._h4h_b, int8=reversed_dim)

    def mlp_output_mp(self, mp_replace, reversed_dim=False):
        if reversed_dim:
            self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w[:, :self._4hh_w.shape[1] //
                                                                                mp_replace.mp_size],
                                                       self._4hh_w,
                                                       int8=reversed_dim,
                                                       allocat_tensor=reversed_dim)
        else:
            self.module.mlp.output_w = mp_replace.copy(self.module.mlp.output_w, self._4hh_w, int8=reversed_dim)
        self.module.mlp.output_b = mp_replace.copy(self.module.mlp.output_b,
                                                   self._4hh_b,
                                                   int8=reversed_dim,
                                                   allocat_tensor=reversed_dim)

    def release_qkv(self):
        del self.module.attention.attn_qkvw
        del self.module.attention.attn_qkvb
        self.module.attention.attn_qkvw = self.qkvw
        self.module.attention.attn_qkvb = self.qkvb

        if self.module.attention.attn_qw is not None:
            qkv_data = [self.module.attention.attn_qw.data, \
                        self.module.attention.attn_qb.data if self.module.attention.attn_qb is not None else None, \
                        self.module.attention.attn_kw.data, \
                        self.module.attention.attn_kb.data if self.module.attention.attn_kb is not None else None, \
                        self.module.attention.attn_vw.data, \
                        self.module.attention.attn_vb.data if self.module.attention.attn_vb is not None else None]
            for data in qkv_data:
                del data

            self.module.attention.attn_qw = self.qw
            self.module.attention.attn_qb = self.qb
            self.module.attention.attn_kw = self.kw
            self.module.attention.attn_kb = self.kb
            self.module.attention.attn_vw = self.vw
            self.module.attention.attn_vb = self.vb

    def release_memory(self):
        self.release_qkv()
        del self.module.attention.attn_ow
        del self.module.attention.attn_ob
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        del self.module.mlp.inter_w
        del self.module.mlp.inter_b
        del self.module.mlp.output_w
        del self.module.mlp.output_b
        self.module.mlp.inter_w = self._h4h_w
        self.module.mlp.inter_b = self._h4h_b
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b

    def copy_data_to_new_module(self):
        if self.attn_nw is None:
            self.module.mlp.attn_nw = self.attn_nw
            self.module.mlp.attn_nb = self.attn_nb
        else:
            self.module.mlp.attn_nw.data.copy_(self.attn_nw.to(get_accelerator().current_device_name()))
            self.module.mlp.attn_nb.data.copy_(self.attn_nb.to(get_accelerator().current_device_name()))

        self.module.norm_w.data.copy_(self.input_nw.to(get_accelerator().current_device_name()))
        self.module.norm_b.data.copy_(self.input_nb.to(get_accelerator().current_device_name()))

    def align_merged_qkv(self):
        if hasattr(self, '_align_merged_qkv'):
            self._align_merged_qkv()

    def partition_merged_qkv(self):
        if hasattr(self, '_partition_merged_qkv'):
            self._partition_merged_qkv()

    def transpose(self):
        self.transpose_attention()

@@ -246,3 +385,110 @@ class BaseTransformerContainer(ABC):

        data = data.reshape(data.shape[-1], data.shape[-2])
        data.to(get_accelerator().current_device_name())
        return data

    def reset_qkv_experimental(self):
        if self.module.attention.attn_qkvw is None:
            self.module.attention.attn_qkvw = torch.empty(self.qw.shape[0] * 3,
                                                          self.qw.shape[0],
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
            self.module.attention.attn_qkvb = torch.empty(self.qw.shape[0] * 3,
                                                          dtype=self.qw.dtype,
                                                          device=self.qw.device)
        self.module.attention.attn_qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.module.attention.attn_qkvb.data[:self.qw.shape[0]] = self.qb.data
        self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
        self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
        self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

        qkv_data = [self.qw.data, \
                    self.qb.data, \
                    self.kw.data, \
                    self.kb.data, \
                    self.vw.data, \
                    self.vb.data]

        self.qw.data = self.module.attention.attn_qkvw.data[:self.qw.shape[0]]
        self.qb.data = self.module.attention.attn_qkvb.data[:self.qw.shape[0]]
        self.kw.data = self.module.attention.attn_qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.kb.data = self.module.attention.attn_qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.module.attention.attn_qkvw.data[2 * self.qw.shape[0]:]
        self.vb.data = self.module.attention.attn_qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data

    def reset_qkv(self):
        self.qkvw.data[:self.qw.shape[0]] = self.qw.data
        self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kw.data
        self.qkvw.data[2 * self.qw.shape[0]:] = self.vw.data
        if self.qkvb is not None:
            self.qkvb.data[:self.qw.shape[0]] = self.qb.data
            self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]] = self.kb.data
            self.qkvb.data[2 * self.qw.shape[0]:] = self.vb.data

        qkv_data = [self.qw.data, \
                    self.qb.data if self.qb is not None else None, \
                    self.kw.data, \
                    self.kb.data if self.kb is not None else None, \
                    self.vw.data, \
                    self.vb.data if self.vb is not None else None]

        self.qw.data = self.qkvw.data[:self.qw.shape[0]]
        self.kw.data = self.qkvw.data[self.qw.shape[0]:2 * self.qw.shape[0]]
        self.vw.data = self.qkvw.data[2 * self.qw.shape[0]:]
        if self.qkvb is not None:
            self.qb.data = self.qkvb.data[:self.qw.shape[0]]
            self.kb.data = self.qkvb.data[self.qw.shape[0]:2 * self.qw.shape[0]]
            self.vb.data = self.qkvb.data[2 * self.qw.shape[0]:]

        for data in qkv_data:
            del data

    def set_params_wo_copy(self, Z3_enabled=False):
        self.module.mlp.attn_nw = self.attn_nw
        self.module.mlp.attn_nb = self.attn_nb
        self.module.norm_w = self.input_nw
        self.module.norm_b = self.input_nb
        self.module.mlp.inter_w = self._h4h_w
        self.module.mlp.inter_b = self._h4h_b
        self.module.mlp.output_w = self._4hh_w
        self.module.mlp.output_b = self._4hh_b
        self.module.attention.attn_ow = self.dense_w
        self.module.attention.attn_ob = self.dense_b
        if not Z3_enabled or self.q_k_v is None:
            self.module.attention.attn_qkvw = self.qkvw
            self.module.attention.attn_qkvb = self.qkvb
        if self.q_k_v is not None:
            if Z3_enabled:
                self.module.attention.attn_qw = self.qw
                self.module.attention.attn_qb = self.qb
                self.module.attention.attn_kw = self.kw
                self.module.attention.attn_kb = self.kb
                self.module.attention.attn_vw = self.vw
                self.module.attention.attn_vb = self.vb
            else:
                self.qw.data = self.qkvw[:self.qw.shape[0], :]
                self.kw.data = self.qkvw[self.qw.shape[0]:2 * self.qw.shape[0], :]
                self.vw.data = self.qkvw[self.qw.shape[0] * 2:, :]
                if self.qkvb is not None:
                    self.qb.data = self.qkvb[:self.qw.shape[0]]
                    self.kb.data = self.qkvb[self.qw.shape[0]:2 * self.qw.shape[0]]
                    self.vb.data = self.qkvb[self.qw.shape[0] * 2:]

    def get_lora_params(self):
        return self.lora_params

    def get_all_params(self):
        if self.q_k_v is not None:
            return [
                self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
                self._4hh_b, self.qw, self.qb, self.kw, self.kb, self.vw, self.vb, self.dense_w, self.dense_b
            ]
        else:
            return [
                self.attn_nw, self.attn_nb, self.input_nw, self.input_nb, self._h4h_w, self._h4h_b, self._4hh_w,
                self._4hh_b, self.qkvw, self.qkvb, self.dense_w, self.dense_b
            ]
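
The reset_qkv() and set_params_wo_copy() paths above work by re-pointing the per-projection q/k/v tensors at slices of the fused qkv buffer, so both views share storage instead of holding copies. A minimal PyTorch sketch of that aliasing idea (standalone toy tensors, illustration only, not the DeepSpeed API):

# Illustration only: q/k/v views aliasing a fused [q; k; v] buffer.
import torch

hidden = 4
qkvw = torch.zeros(3 * hidden, hidden)   # fused qkv weight
qw = qkvw[:hidden]                       # q view shares storage with qkvw
kw = qkvw[hidden:2 * hidden]
vw = qkvw[2 * hidden:]
qw.fill_(1.0)                            # writing through the view...
assert torch.equal(qkvw[:hidden], torch.ones(hidden, hidden))  # ...is visible in the fused buffer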

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.
from .base import *

@@ -8,6 +11,7 @@ from deepspeed.accelerator import get_accelerator

class BaseTransformerMoEContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        # Call the init function of the parent class to initialize the tensors and configs from parent class
        super().__init__(**kwargs)

@@ -16,9 +20,7 @@ class BaseTransformerMoEContainer(BaseTransformerContainer):

        self.ep_world_size = dist.get_world_size()
        self.local_ep_size = 1 if self.num_experts < self.ep_world_size else self.num_experts // self.ep_world_size

        self.layer_norm_eps = self.config.layer_norm_eps if hasattr(self.config, 'layer_norm_eps') else 1e-12,

        # MoE models will have a list of mlp related tensors
        self._h4h_w = []

@@ -102,40 +104,27 @@ class BaseTransformerMoEContainer(BaseTransformerContainer):

        gpu_index = dist.get_rank()
        for ep_index in range(self.local_ep_size):
            # mlp inter
            self.module.mlp[ep_index].inter_w.data = self._h4h_w[gpu_index * self.local_ep_size + ep_index].to(
                get_accelerator().current_device_name())
            self.module.mlp[ep_index].inter_b.data = self._h4h_b[gpu_index * self.local_ep_size + ep_index].to(
                get_accelerator().current_device_name())
            # mlp output
            self.module.mlp[ep_index].output_w.data = self._4hh_w[gpu_index * self.local_ep_size + ep_index].to(
                get_accelerator().current_device_name())
            self.module.mlp[ep_index].output_b.data = self._4hh_b[gpu_index * self.local_ep_size + ep_index].to(
                get_accelerator().current_device_name())

    def copy_data_to_new_module(self):
        self.module.attn_nw.data = self.attn_nw.to(get_accelerator().current_device_name())
        self.module.attn_nb.data = self.attn_nb.to(get_accelerator().current_device_name())

        self.module.norm_w.data.copy_(self.input_nw.to(get_accelerator().current_device_name()))
        self.module.norm_b.data.copy_(self.input_nb.to(get_accelerator().current_device_name()))

        if self.config.moe.type == 'residual':
            self.module.res_mlp.inter_w.data = self._res_h4h_w.to(get_accelerator().current_device_name())
            self.module.res_mlp.inter_b.data = self._res_h4h_b.to(get_accelerator().current_device_name())
            self.module.res_mlp.output_w.data = self._res_4hh_w.to(get_accelerator().current_device_name())
            self.module.res_mlp.output_b.data = self._res_4hh_b.to(get_accelerator().current_device_name())
            self.module.res_coef.data = self._res_coef.to(get_accelerator().current_device_name())

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference

@@ -8,6 +11,7 @@ from ..policy import TransformerPolicy

class DS_BERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -23,6 +27,7 @@ class DS_BERTContainer(BaseTransformerContainer):

class HFBertLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=False)
        self.client_module = client_module

@@ -39,10 +44,18 @@ class HFBertLayerPolicy(TransformerPolicy):

            HFBertLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        if self.pre_attn_norm:
            attention_layernorm = self.client_module.PostAttentionLayerNorm
        else:
            attention_layernorm = self.client_module.attention.output.LayerNorm
        return self.client_module.attention.self.query.weight.shape[1], \
                self.client_module.attention.self.num_attention_heads, \
                attention_layernorm.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attention.self.query.weight
        qb = self.client_module.attention.self.query.bias
        kw = self.client_module.attention.self.key.weight

@@ -50,8 +63,8 @@ class HFBertLayerPolicy(TransformerPolicy):

        vw = self.client_module.attention.self.value.weight
        vb = self.client_module.attention.self.value.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
        return qkvw, \
               qkvb, \

@@ -79,3 +92,6 @@ class HFBertLayerPolicy(TransformerPolicy):

               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias

    def get_lora_params(self):
        return []
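
Across the policies in this merge request, get_hidden_heads() grows a third return value (the attention layernorm epsilon), which the base container unpacks in set_hidden_heads() and later forwards as layer_norm_eps in the inference config. A schematic sketch of that three-value handshake (simplified stand-in classes, not the real DeepSpeed types):

# Illustration only: the widened get_hidden_heads() -> set_hidden_heads() contract.
class ToyPolicy:
    def get_hidden_heads(self):
        return 768, 12, 1e-12  # hidden_size, num_attention_heads, layernorm epsilon

class ToyContainer:
    def set_hidden_heads(self, hidden_size, num_attention_heads, epsilon):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.layernorm_epsilon = epsilon

c = ToyContainer()
c.set_hidden_heads(*ToyPolicy().get_hidden_heads())
assert c.layernorm_epsilon == 1e-12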

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer

@@ -7,10 +10,13 @@ from ..policy import TransformerPolicy

from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_get_lora

supported_models = {None}


class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -24,13 +30,9 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):

        self.module.config.scale_attention = self.scale_attention
        return self.module

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        param_names = (

@@ -58,58 +60,39 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):

                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv)
        for i in range(2, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(4, 10):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(10, 12):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])


class BLOOMLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True, split_qkv=False):
        super().__init__(inference, linear_layer=True, use_load_prefix=use_load_prefix, split_qkv=split_qkv)
        self.client_module = client_module
        try:
            import transformers
            BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
            global supported_models
            supported_models.update({transformers.models.bloom.modeling_bloom.BloomModel})
        except Exception as e:
            print(f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}")
            BLOOMLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.self_attention.hidden_size, \
                self.client_module.self_attention.num_heads, \
                self.client_module.input_layernorm.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        return self.client_module.self_attention.query_key_value.weight, \
               self.client_module.self_attention.query_key_value.bias, \
               self.client_module.self_attention.dense.weight, \

@@ -126,3 +109,14 @@ class BLOOMLayerPolicy(TransformerPolicy):

               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias

    def get_lora_params(self):
        all_lora_params = []
        for p in [
                self.client_module.mlp.dense_h_to_4h, \
                self.client_module.mlp.dense_4h_to_h, \
                self.client_module.self_attention.query_key_value, \
                self.client_module.self_attention.dense
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference

@@ -8,6 +11,7 @@ from ..policy import TransformerPolicy

class DS_CLIPContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -21,6 +25,7 @@ class DS_CLIPContainer(BaseTransformerContainer):

class HFCLIPLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=True, scale_attention=True)
        self.client_module = client_module

@@ -35,7 +40,11 @@ class HFCLIPLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.self_attn.q_proj.weight.shape[1], \
                self.client_module.self_attn.num_heads, \
                self.client_module.layer_norm1.eps

    def get_q_k_v(self):
        return None

    def attention(self):
        qw = self.client_module.self_attn.q_proj.weight

@@ -64,3 +73,6 @@ class HFCLIPLayerPolicy(TransformerPolicy):

               self.client_module.layer_norm2.bias, \
               self.client_module.layer_norm1.weight, \
               self.client_module.layer_norm1.bias

    def get_lora_params(self):
        return []

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference

@@ -8,6 +11,7 @@ from ..policy import TransformerPolicy

class DS_DistilBERTContainer(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -41,9 +45,13 @@ class HFDistilBertLayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attention.q_lin.weight.shape[1], \
                self.client_module.attention.n_heads, \
                self.client_module.sa_layer_norm.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        qw = self.client_module.attention.q_lin.weight
        qb = self.client_module.attention.q_lin.bias
        kw = self.client_module.attention.k_lin.weight

@@ -51,8 +59,8 @@ class HFDistilBertLayerPolicy(TransformerPolicy):

        vw = self.client_module.attention.v_lin.weight
        vb = self.client_module.attention.v_lin.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
        return qkvw, \
               qkvb, \

@@ -73,3 +81,6 @@ class HFDistilBertLayerPolicy(TransformerPolicy):

               attention_layernorm.bias, \
               transformer_layernorm.weight, \
               transformer_layernorm.bias

    def get_lora_params(self):
        return []

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .megatron import MegatronContainer
from .meta_tensor import MetaTensorContainer

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from abc import ABC


class MegatronContainer(ABC):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.megatron_v2 = self.policy.is_megatron_v2

    def _align_qkv_transposed(self, x):
        attention_head_size = x.shape[-1] // self.num_attention_heads
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
        x_1 = x.view(*new_x_shape)
        (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
        if len(q.shape) > 2:
            return torch.cat((q.reshape(q.shape[0], -1), k.reshape(q.shape[0], -1), v.reshape(q.shape[0], -1)),
                             dim=-1).reshape(x.shape)
        else:
            return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)

    def _align_qkv(self, x):
        attention_head_size = x.shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        x_1 = x.view(*new_x_shape)
        div_dim = len(x_1.size()) - 2 if len(x.shape) == 2 else -1
        (q, k, v) = torch.split(x_1, (x_1.shape[div_dim] // 3), dim=div_dim)
        if len(q.shape) > 2:
            x.data.copy_(
                torch.cat((q.reshape(-1, q.shape[-1]), k.reshape(-1, q.shape[-1]), v.reshape(-1, q.shape[-1])),
                          dim=0).reshape(x.shape))
        else:
            x.data.copy_(torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape))

    def _align_merged_qkv(self):
        if hasattr(self.qkvw, 'ds_id'):
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
            param_list = [self.qkvw, self.qkvb]
            non_active_params = [param for param in param_list if (hasattr(param, 'ds_id') and \
                                 param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
            with GatheredParameters(non_active_params):
                self._align_qkv(self.qkvw)
                self._align_qkv(self.qkvb)
        else:
            self._align_qkv(self.qkvw)
            self._align_qkv(self.qkvb)

    def _partition_qkv(self, x):
        q_k_v = torch.split(x, (x.shape[0] // 3), dim=0)
        attention_head_size = q_k_v[0].shape[0] // self.num_attention_heads
        new_x_shape = (self.num_attention_heads, attention_head_size) + x.size()[1:]
        q, k, v = [data.view(*new_x_shape) for data in q_k_v]
        if len(q.shape) > 2:
            x.data.copy_(torch.cat((q, k, v), dim=-2).reshape(-1, q.shape[-1]))
        else:
            x.data.copy_(torch.cat((q, k, v), dim=-1).reshape(-1))

    def _partition_merged_qkv(self):
        if hasattr(self.qkvw, 'ds_id'):
            from deepspeed.runtime.zero import GatheredParameters
            from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
            param_list = [self.qkvw, self.qkvb]
            non_active_params = [param for param in param_list if (hasattr(param, 'ds_id') and \
                                 param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]
            with GatheredParameters(non_active_params):
                self._partition_qkv(self.qkvw)
                self._partition_qkv(self.qkvb)
        else:
            self._partition_qkv(self.qkvw)
            self._partition_qkv(self.qkvb)

    def transpose(self):
        super().transpose()
        if self.megatron_v2:
            self.qkvw = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvw).contiguous())
            self.qkvb = torch.nn.parameter.Parameter(self._align_qkv_transposed(self.qkvb).contiguous())
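
The new _align_qkv/_partition_qkv helpers convert, in place, between the Megatron-v2 head-interleaved qkv layout and the contiguous [q; k; v] layout used by the fused DeepSpeed parameters. A self-contained PyTorch sketch of the regrouping direction (toy shapes, illustration only, not the library function itself):

# Illustration only: regroup a head-interleaved [q|k|v] column into contiguous q, k, v blocks.
import torch

heads, head_dim, hidden = 2, 3, 6                            # toy sizes; real models are much larger
x = torch.arange(3 * hidden).float().view(3 * hidden, 1)     # stand-in fused qkv "weight" column
x_heads = x.view(heads, 3 * head_dim, 1)                     # per head: its q rows, then k rows, then v rows
q, k, v = torch.split(x_heads, head_dim, dim=1)
aligned = torch.cat((q.reshape(-1, 1), k.reshape(-1, 1), v.reshape(-1, 1)), dim=0)
print(aligned.view(-1))                                      # all q rows first, then all k rows, then all v rows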

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod


class MetaTensorContainer(ABC):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.is_meta = False
        self.ckpt_load_enabled = True

    def initialize_tensors(self, enable_training=False):
        super().initialize_tensors(enable_training=enable_training)
        self.is_meta = self.qkvw.is_meta

    def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
        if self.is_meta:
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace, mp_group, tp_size)

    def copy_data_to_new_module(self):
        if self.is_meta:

@@ -53,6 +57,5 @@ class MetaTensorContainer(ABC):

            of q, k, and v are stored together and needs to split in the
            DeepSpeed-Inference API.
        """
        raise NotImplementedError("A load_params() function must be defined in the model container \
                                  when inheriting the MetaTensorContainer feature")

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference

@@ -6,6 +9,7 @@ from ..policy import TransformerPolicy

class DS_GPT2Container(BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -33,9 +37,13 @@ class HFGPT2LayerPolicy(TransformerPolicy):

    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
                self.client_module.attn.num_heads, \
                self.client_module.ln_1.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        return self.client_module.attn.c_attn.weight, \
               self.client_module.attn.c_attn.bias, \
               self.client_module.attn.c_proj.weight, \

@@ -52,3 +60,6 @@ class HFGPT2LayerPolicy(TransformerPolicy):

               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        return []

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .base import *
from .features.meta_tensor import MetaTensorContainer

@@ -10,8 +13,11 @@ from ..policy import transformer_param_names

from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora


class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

@@ -36,36 +42,20 @@ class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):

            'ln_1.weight', \
            'ln_1.bias'
        )

        maybe_copy_qkv(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
                       split_qkv=self.policy.split_qkv)
        for i in range(3, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
                       prefix + param_names[i])
        for i in range(4, 8):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(8, 10):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i + 2],
                       prefix + param_names[i])

@@ -82,15 +72,24 @@ class HFGPTJLayerPolicy(TransformerPolicy):

            HFGPTJLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.embed_dim, \
                self.client_module.attn.num_attention_heads, \
                self.client_module.ln_1.eps

    def get_q_k_v(self):
        return self.client_module.attn.q_proj.weight, \
               None, \
               self.client_module.attn.k_proj.weight, \
               None, \
               self.client_module.attn.v_proj.weight, \
               None

    def attention(self, enable_training=False):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        return qkvw, \
               None, \

@@ -108,3 +107,16 @@ class HFGPTJLayerPolicy(TransformerPolicy):

               None, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias

    def get_lora_params(self):
        all_lora_params = []
        for p in [
                self.client_module.mlp.fc_in, \
                self.client_module.mlp.fc_out, \
                self.client_module.attn.q_proj, \
                self.client_module.attn.k_proj, \
                self.client_module.attn.v_proj, \
                self.client_module.attn.out_proj, \
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .base import * from .base import *
from .features.meta_tensor import MetaTensorContainer from .features.meta_tensor import MetaTensorContainer
...@@ -10,8 +13,11 @@ from ..policy import transformer_param_names ...@@ -10,8 +13,11 @@ from ..policy import transformer_param_names
from ..policy import maybe_copy from ..policy import maybe_copy
from ..policy import maybe_copy_qkv from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer): class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -39,40 +45,25 @@ class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer): ...@@ -39,40 +45,25 @@ class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
'ln_1.weight', \ 'ln_1.weight', \
'ln_1.bias' 'ln_1.bias'
) )
maybe_copy_qkv( maybe_copy_qkv(module.attention,
module.attention,
sd,
weight_quantizer,
mp_replace,
'attn_qkvw',
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention,
sd, sd,
weight_quantizer, weight_quantizer,
mp_replace, mp_replace,
transformer_param_names[i - 1], 'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
prefix + param_names[i]) prefix + param_names[i])
for i in range(5, 11): for i in range(5, 11):
maybe_copy(module.mlp, maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i]) prefix + param_names[i])
for i in range(11, 13): for i in range(11, 13):
maybe_copy(module, maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i]) prefix + param_names[i])
class HFGPTNEOLayerPolicy(TransformerPolicy):

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=False)
        self.client_module = client_module
...@@ -83,15 +74,24 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
            HFGPTNEOLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        return self.client_module.attn.attention.embed_dim, \
               self.client_module.attn.attention.num_heads, \
               self.client_module.ln_1.eps

    def get_q_k_v(self):
        return self.client_module.attn.attention.q_proj.weight, \
               None, \
               self.client_module.attn.attention.k_proj.weight, \
               None, \
               self.client_module.attn.attention.v_proj.weight, \
               None

    def attention(self, enable_training=False):
        qw = self.client_module.attn.attention.q_proj.weight
        kw = self.client_module.attn.attention.k_proj.weight
        vw = self.client_module.attn.attention.v_proj.weight

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)

        return qkvw, \
               None, \
...@@ -109,3 +109,16 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
               self.client_module.ln_2.bias, \
               self.client_module.ln_1.weight, \
               self.client_module.ln_1.bias
    def get_lora_params(self):
        all_lora_params = []
        for p in [
                self.client_module.mlp.c_fc, \
                self.client_module.mlp.c_proj, \
                self.client_module.attn.attention.q_proj, \
                self.client_module.attn.attention.k_proj, \
                self.client_module.attn.attention.v_proj, \
                self.client_module.attn.attention.out_proj, \
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params
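
# --- Hedged usage sketch (not part of the original file) ---------------------
# Shows how HFGPTNEOLayerPolicy is typically wrapped around one Hugging Face
# GPT-Neo block. The model id and the variable names `gptneo` / `block` are
# illustrative assumptions, not taken from this repository.
from transformers import GPTNeoForCausalLM

gptneo = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
block = gptneo.transformer.h[0]                     # one GPTNeoBlock
policy = HFGPTNEOLayerPolicy(block, inference=True)

hidden, heads, eps = policy.get_hidden_heads()      # embed_dim, num_heads, ln_1.eps
qw, qb, kw, kb, vw, vb = policy.get_q_k_v()         # GPT-Neo has no q/k/v biases, so qb/kb/vb are None
lora_params = policy.get_lora_params()              # one maybe_get_lora(...) entry per linear module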
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .base import *
from .features.meta_tensor import MetaTensorContainer
...@@ -10,10 +13,11 @@ from ..policy import transformer_param_names
from ..policy import maybe_copy
from packaging import version as pkg_version
from ..policy import maybe_get_lora


class DS_GPTNEOXContainer(MetaTensorContainer, MegatronContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

...@@ -57,26 +61,13 @@ class DS_GPTNEOXContainer(MetaTensorContainer,
                       split_qkv=self.policy.split_qkv,
                       heads=self.policy.client_module.attention.num_attention_heads)
        for i in range(2, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(4, 10):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        for i in range(10, 12):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
class GPTNEOXLayerPolicy(TransformerPolicy):
...@@ -102,10 +93,14 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
        else:
            attention = self.client_module.self_attention

        return self.client_module.attention.hidden_size, \
               self.client_module.attention.num_attention_heads, \
               self.client_module.input_layernorm.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
...@@ -127,3 +122,19 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias
    def get_lora_params(self):
        if GPTNEOXLayerPolicy.version == 0:
            attention = self.client_module.attention
        else:
            attention = self.client_module.self_attention

        all_lora_params = []
        for p in [
                self.client_module.mlp.dense_h_to_4h, \
                self.client_module.mlp.dense_4h_to_h, \
                attention.query_key_value, \
                attention.dense
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params
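
# --- Hedged usage sketch (not part of the original file) ---------------------
# GPTNEOXLayerPolicy.version is a class-level switch: version 0 targets the layout
# where a layer exposes `.attention`, later versions target `.self_attention`.
# The model id and the assumption that the constructor takes the layer as its
# first argument (as the other policies here do) are illustrative, not from this file.
from transformers import GPTNeoXForCausalLM

neox = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")
neox_layer = neox.gpt_neox.layers[0]                # one GPTNeoXLayer, exposes `.attention`
GPTNEOXLayerPolicy.version = 0
neox_policy = GPTNEOXLayerPolicy(neox_layer)
hidden, heads, eps = neox_policy.get_hidden_heads() # hidden_size, num_attention_heads, input_layernorm.eps
lora_params = neox_policy.get_lora_params()         # fused query_key_value + dense + both MLP projections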
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .base import *
from .features.megatron import MegatronContainer
...@@ -9,6 +12,7 @@ from packaging import version as pkg_version


class DS_MegatronGPTContainer(MegatronContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

...@@ -36,9 +40,7 @@ class MegatronLayerPolicy(TransformerPolicy):
    use_mup = False

    def __init__(self, client_module, inference=True):
        super().__init__(inference, megatron_v2=MegatronLayerPolicy.megatron_v2, use_mup=MegatronLayerPolicy.use_mup)
        self.client_module = client_module
        # we use megatron version to differentiate between the old and new
        # megatron-lm source code
...@@ -54,9 +56,13 @@ class MegatronLayerPolicy(TransformerPolicy):
    def get_hidden_heads(self):
        return self.client_module.attention.query_key_value.weight.shape[1], \
               self.client_module.attention.num_attention_heads, \
               self.client_module.input_layernorm.eps

    def get_q_k_v(self):
        return None

    def attention(self, enable_training=False):
        if self.inference:
            if MegatronLayerPolicy.version == 0:
                attention = self.client_module.attention
...@@ -104,3 +110,6 @@ class MegatronLayerPolicy(TransformerPolicy):
               self.client_module.post_attention_layernorm.bias, \
               self.client_module.input_layernorm.weight, \
               self.client_module.input_layernorm.bias

    def get_lora_params(self):
        return []
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .base import *
from .base_moe import *
...@@ -10,6 +13,7 @@ from packaging import version as pkg_version


class DS_MegatronGPTMoEContainer(MegatronContainer, BaseTransformerMoEContainer):

    def __init__(self, policy, config, model_config, layer_id):
        super().__init__(policy, config, model_config, layer_id)
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .base import *
from .features.meta_tensor import MetaTensorContainer
...@@ -9,10 +12,12 @@ from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from ..policy import maybe_get_lora
from deepspeed.utils.types import ActivationFuncType


class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

...@@ -50,32 +55,16 @@ class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer):
                           weight_quantizer,
                           mp_replace,
                           transformer_param_names[i // 3],
                           [prefix + param_names[i], prefix + param_names[i + 1], prefix + param_names[i + 2]],
                           split_qkv=self.policy.split_qkv)
        for i in range(6, 8):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        for i in range(8, 14):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
        for i in range(14, 16):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
                       prefix + param_names[i])
...@@ -83,27 +72,40 @@ class HFOPTLayerPolicy(TransformerPolicy):
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True):
        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
        self.client_module = client_module

        try:
            import transformers
            HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
        except:
            HFOPTLayerPolicy._orig_layer_class = None

        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
                                                                     "activation_function"):
            if TransformerPolicy.hf_model_config.activation_function == "relu":
                self.mlp_act_func_type = ActivationFuncType.ReLU
            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
                self.mlp_act_func_type = ActivationFuncType.GELU
            else:
                raise ValueError("Unsupported activation function: {}".format(
                    TransformerPolicy.hf_model_config.activation_function))
        else:
            self.mlp_act_func_type = ActivationFuncType.ReLU  # default
    def get_hidden_heads(self):
        return self.client_module.self_attn.embed_dim, \
               self.client_module.self_attn.num_heads, \
               self.client_module.self_attn_layer_norm.eps

    def get_q_k_v(self):
        return self.client_module.self_attn.q_proj.weight, \
               self.client_module.self_attn.q_proj.bias, \
               self.client_module.self_attn.k_proj.weight, \
               self.client_module.self_attn.k_proj.bias, \
               self.client_module.self_attn.v_proj.weight, \
               self.client_module.self_attn.v_proj.bias

    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
...@@ -113,9 +115,8 @@ class HFOPTLayerPolicy(TransformerPolicy):
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)

        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
...@@ -132,3 +133,16 @@ class HFOPTLayerPolicy(TransformerPolicy):
               self.client_module.final_layer_norm.bias, \
               self.client_module.self_attn_layer_norm.weight, \
               self.client_module.self_attn_layer_norm.bias
    def get_lora_params(self):
        all_lora_params = []
        for p in [
                self.client_module.fc1, \
                self.client_module.fc2, \
                self.client_module.self_attn.q_proj, \
                self.client_module.self_attn.k_proj, \
                self.client_module.self_attn.v_proj, \
                self.client_module.self_attn.out_proj, \
        ]:
            all_lora_params.append(maybe_get_lora(p))
        return all_lora_params
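
# --- Hedged usage sketch (not part of the original file) ---------------------
# Wraps one Hugging Face OPTDecoderLayer. The constructor falls back to ReLU unless
# an activation_function is found on the shared TransformerPolicy.hf_model_config;
# the model id below is an illustrative assumption.
from transformers import OPTForCausalLM

opt = OPTForCausalLM.from_pretrained("facebook/opt-125m")
decoder_layer = opt.model.decoder.layers[0]         # one OPTDecoderLayer
opt_policy = HFOPTLayerPolicy(decoder_layer, inference=True)

hidden, heads, eps = opt_policy.get_hidden_heads()  # embed_dim, num_heads, self_attn_layer_norm.eps
qw, qb, kw, kb, vw, vb = opt_policy.get_q_k_v()     # per-projection weights and biases
attn_params = opt_policy.attention()                # fused QKV weight/bias followed by the output projection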
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from torch.nn.parameter import Parameter
...@@ -9,6 +11,7 @@ from ...model_implementations.diffusers.unet import DSUNet


class UNetPolicy(DSPolicy):

    def __init__(self):
        super().__init__()
        try:
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from ..policy import DSPolicy
from ...model_implementations.diffusers.vae import DSVAE


class VAEPolicy(DSPolicy):

    def __init__(self):
        super().__init__()
        try:
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import copy
import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj, model, config, micro_batch_size, max_seq_length, seed, preln, fp16=True):
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(batch_size=micro_batch_size,
                                                     max_seq_length=max_seq_length,
                                                     hidden_size=config.hidden_size,
                                                     heads=config.num_attention_heads,
                                                     attn_dropout_ratio=config.attention_probs_dropout_prob,
                                                     hidden_dropout_ratio=config.hidden_dropout_prob,
                                                     num_hidden_layers=config.num_hidden_layers,
                                                     initializer_range=config.initializer_range,
                                                     seed=seed,
                                                     fp16=fp16,
                                                     pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)
...@@ -71,14 +66,7 @@ def module_inject(layer_obj,
            setattr(model, name, copy.deepcopy(new_module))
        else:
            module_inject(layer_obj, child, config, micro_batch_size, max_seq_length, seed, preln, fp16)

    return model
...
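
# --- Hedged usage sketch (not part of the original file) ---------------------
# module_inject walks the model tree and swaps every `layer_obj` instance for a
# fused DeepSpeedTransformerLayer built from the Hugging Face config. The model id
# below is an illustrative assumption.
from transformers import BertModel
from transformers.models.bert.modeling_bert import BertLayer

bert = BertModel.from_pretrained("bert-base-uncased")
bert = module_inject(layer_obj=BertLayer,
                     model=bert,
                     config=bert.config,
                     micro_batch_size=8,
                     max_seq_length=128,
                     seed=42,
                     preln=False,   # post-LayerNorm, matching the original BERT layout
                     fp16=True)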
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
from deepspeed import comm as dist
...@@ -10,6 +13,7 @@ from deepspeed.accelerator import get_accelerator


class LinearAllreduce(nn.Module):

    def __init__(self, weight, bias=None, mp_group=None):
        super(LinearAllreduce, self).__init__()
        self.weight = weight
...@@ -26,6 +30,7 @@ class LinearAllreduce(nn.Module):


class LinearLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(LinearLayer, self).__init__()
        if weight is not None:
...@@ -33,9 +38,7 @@ class LinearLayer(nn.Module):
            self.bias = bias
        else:
            self.weight = Parameter(
                torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name()))
            self.bias = Parameter(
                torch.empty(weight_shape[0],
...@@ -51,26 +54,35 @@ class LinearLayer(nn.Module):


class Normalize(nn.Module):

    def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None):
        super(Normalize, self).__init__()
        if weight is not None:
            self.weight = weight
            self.bias = bias
        else:
            self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
            self.weight = self.norm.weight
            self.bias = self.norm.bias

        self.eps = eps

    def forward(self, input):
        return nn.functional.layer_norm(input, input.shape[-1:], self.weight, self.bias, eps=self.eps)
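
# --- Hedged sketch (not part of the original file) ---------------------------
# Normalize either builds its own LayerNorm (dim/eps path) or wraps pre-existing
# weight/bias tensors; both paths end in the functional layer_norm call above.
# The sizes below are illustrative.
_ln = Normalize(dim=768, dtype=torch.float, eps=1e-5)
_x = torch.randn(2, 16, 768, device=get_accelerator().current_device_name())
_y = _ln(_x)   # same shape as _x, normalized over the last dimension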
class EmbeddingLayer(nn.Module):

    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
        super(EmbeddingLayer, self).__init__()
        if weight is None:
            self.weight = Parameter(
                torch.empty(weight_shape[0],
                            weight_shape[1],
                            dtype=dtype,
                            device=get_accelerator().current_device_name()))
        else:
            self.weight = weight

    def forward(self, input):
        return F.embedding(input, self.weight)
...@@ -80,20 +92,19 @@ class OPTEmbedding(EmbeddingLayer):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, weight_shape=None, weight=None, bias=None):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(weight_shape, weight=weight)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]
...
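
# --- Hedged worked example (not part of the original file) -------------------
# For a left-padded batch, the cumulative-sum trick above yields positions that
# start at 0 on the first real token; the elided remainder of forward() then adds
# self.offset before the embedding lookup.
_mask = torch.tensor([[0, 0, 1, 1, 1]])
_positions = (torch.cumsum(_mask, dim=1).type_as(_mask) * _mask).long() - 1
# _positions == tensor([[-1, -1, 0, 1, 2]])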