# Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team import json import math import torch from torch.autograd import Function #from ...inference.engine import inference_cuda_module, specialized_mode # Cuda modules will be imported if needed inference_cuda_module = None specialized_mode = None import torch.nn as nn from .ds_attention import DeepSpeedSelfAttention from .config import DeepSpeedInferenceConfig from ....moe.sharded_moe import TopKGate from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): """Initialize the DeepSpeed Transformer Config. Arguments: hidden_size: The hidden size of the transformer layer intermediate_size: The intermediate size of the feed-forward part of transformer layer heads: The number of heads in the self-attention of the transformer layer num_hidden_layers: The number of transformer layers layer_norm_eps: The epsilon value for the layer norm local_rank: Optional: The rank of GPU running the transformer kernel, it is not required to use if the model already set the current device, otherwise need to set it so that the transformer kernel can work on the right device mp_size (optional): This argument is mainly used to create the parameters on the kernel side using model-parallel architecture. If the client model already takes care of this, there is no need to pass this argument. fp16: Enable half-precision computation pre_layer_norm: Select between Pre-LN or Post-LN transformer architecture stochastic_mode: Enable for high performance, please note that this flag has some level of non-determinism and can produce different results on different runs. However, we have seen that by enabling it, the pretraining tasks such as BERT are not affected and can obtain a high accuracy level. On the other hand, for the downstream tasks, such as fine-tuning, we recommend to turn it off in order to be able to reproduce the same result through the regular kernel execution. scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation. return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor """ def __init__(self, hidden_size=-1, intermediate_size=-1, heads=-1, num_hidden_layers=-1, layer_norm_eps=1e-12, local_rank=-1, mp_size=1, fp16=False, q_int8=False, pre_layer_norm=True, stochastic_mode=False, scale_attention=True, triangular_masking=True, local_attention=False, window_size=256, return_tuple=True, moe_experts=1, global_experts=1, k=1, capacity_factor=1., eval_capacity_factor=1., min_capacity=1, noisy_gate_policy=None, drop_tokens=True, use_rts=False, mlp_type='standard', scale_attn_by_inverse_layer_idx=False): super(DeepSpeedMoEInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, q_int8, pre_layer_norm, stochastic_mode, scale_attention, triangular_masking, local_attention, window_size, return_tuple) self.moe_experts = moe_experts self.k = k self.capacity_factor = capacity_factor self.eval_capacity_factor = eval_capacity_factor self.min_capacity = min_capacity self.noisy_gate_policy = noisy_gate_policy self.drop_tokens = drop_tokens self.use_rts = use_rts self.global_experts = global_experts self.mlp_type = mlp_type self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx @classmethod def from_dict(cls, json_object): config = DeepSpeedInferenceConfig() for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file): with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) class DeepSpeedMLPFunction(Function): @staticmethod def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group, async_op): if config.q_int8: intermediate = inference_cuda_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon, q_scales[2], (q_groups * (2**merge_count)), config.pre_layer_norm) output = inference_cuda_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups, (merge_count)) else: mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \ inference_cuda_module.fused_gemm_gelu_fp32 output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group, async_op=async_op) return output + output_b @staticmethod def backward(ctx, grad_output): raise RuntimeError('You are running with DeepSpeed Inference mode. \ Please switch to Training mode for running backward!') class DeepSpeedMoEMLP(nn.Module): def __init__(self, config, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False, mp_group=None): super(DeepSpeedMoEMLP, self).__init__() self.config = config self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size)) interm_size = self.config.intermediate_size // (1 if mp_group is None else dist.get_world_size(group=mp_group)) self.inter_w = nn.Parameter(torch.Tensor(self.config.hidden_size, interm_size)) self.inter_b = nn.Parameter(torch.Tensor(interm_size)) self.output_w = nn.Parameter(torch.Tensor((interm_size), self.config.hidden_size)) self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) # used for quantization self.q_scales = q_scales self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups self.merge_count = int(math.log2(merge_count)) self.mp_group = mp_group def forward(self, input, async_op=False): return DeepSpeedMLPFunction.apply(input, self.inter_w, self.inter_b, self.config, self.output_b, self.output_w, self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op) class DeepSpeedMoEInference(nn.Module): """Initialize the DeepSpeed MoE Transformer Layer. Arguments: layer_id: The layer index starting from 0, e.g. if model has 24 transformer layers, layer_id will be 0,1,2...23 when each layer object is instantiated config: An object of DeepSpeedInferenceConfig mp_group: Model parallelism group initialized on the modeling side. quantize_scales: This argument groups all the layers' scales used for quantization quantize_groups: Number of groups used for quantizing the model merge_count: Shows the number of model-parallel checkpoints merged before running inference. We use this argument to control the quantization scale for the model parameters if a bigger quantize-grouping than 1 is used. mlp_extra_grouping: This flag is used to show a 2x higher number of groups used for the MLP part of a Transformer layer. We use this feature for quantization to reduce the convergence impact for specific downstream tasks. """ layer_id = 0 def __init__(self, config, mp_group=None, ep_group=None, expert_mp_group=None, quantize_scales=None, quantize_groups=1, merge_count=1, mlp_extra_grouping=False): super(DeepSpeedMoEInference, self).__init__() self.config = config self.config.layer_id = DeepSpeedMoEInference.layer_id global inference_cuda_module global specialized_mode if inference_cuda_module is None: specialized_mode = False # InferenceSpecializedBuilder is not among DeepSpeed provided builder yet, so we infer by builder name string builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder") if builder != None and builder.is_compatible(): inference_cuda_module = builder.load() specialized_mode = True else: inference_cuda_module = InferenceBuilder().load() self.config.specialized_mode = specialized_mode DeepSpeedMoEInference.layer_id += 1 self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count) self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) if config.mlp_type == 'residual': self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping, mp_group) self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2)) self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \ inference_cuda_module.softmax_fp32 self.vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ inference_cuda_module.vector_matmul_fp32 config.mp_size = 1 self.mlp = nn.ModuleList( DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping, expert_mp_group) for i in range(self.config.moe_experts)) self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k, self.config.capacity_factor, self.config.eval_capacity_factor, self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens, self.config.use_rts) self.ep_group = ep_group self.mp_group = mp_group self.expert_mp_group = expert_mp_group print("DeepSpeed MoE Transformer Inference config is ", self.config.__dict__) self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \ inference_cuda_module.bias_residual_fp32 self.ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \ inference_cuda_module.layer_norm_fp32 self.einsum_sec_sm_ecm = inference_cuda_module.einsum_sec_sm_ecm_fp16 if self.config.fp16 or self.config.q_int8 else \ inference_cuda_module.einsum_sec_sm_ecm_fp32 def res_coef_func(self, inp, async_op): inp = self.vector_matmul_func(inp, self.res_coef, async_op) return self.coef_func(inp, torch.empty(1), False, False, False, 256, async_op) def moe_gate_einsum(self, attention_output): _, combined_weights, dispatch_mask, _ = self.moe_gate( attention_output.view(-1, self.config.hidden_size), None, ) dispatched_attention = self.einsum_sec_sm_ecm(dispatch_mask.type_as(attention_output), attention_output.view(-1, self.config.hidden_size)) return dispatched_attention, combined_weights def expert_exec(self, dispatched_input): dispatched_input = dispatched_input.reshape(self.config.global_experts // self.config.moe_experts, self.config.moe_experts, -1, self.config.hidden_size) chunks = dispatched_input.chunk(self.config.moe_experts, dim=1) expert_outputs = torch.empty(( self.config.moe_experts, chunks[0].shape[0], ) + chunks[0].shape[2:], dtype=dispatched_input.dtype, device=dispatched_input.device) for chunk, expert in zip(chunks, range(len(self.mlp))): expert_outputs[expert] = self.mlp[expert](chunk.view(-1, dispatched_input.shape[-2], dispatched_input.shape[-1])) return expert_outputs def _alltoall(self, dispatched_attention): if dist.get_world_size(group=self.ep_group) > 1: dispatched_input = torch.empty_like(dispatched_attention) dist.all_to_all_single(dispatched_input, dispatched_attention, group=self.ep_group) return dispatched_input else: return dispatched_attention def scale_expert_output(self, attention_output, expert_output, combined_weights): combined_output = torch.matmul( combined_weights.type_as(attention_output).reshape(combined_weights.shape[0], -1), expert_output.reshape(-1, expert_output.shape[-1])) return combined_output.reshape(attention_output.shape) def forward(self, input, input_mask=None, attention_mask=None, head_mask=None, layer_past=None, get_key_value=False, get_present=False, encoder_output=None, enc_dec_attn_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False, output_attentions=False): get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask input_type = input.dtype if (self.config.fp16 or self.config.q_int8) \ and input.dtype == torch.float: input = input.half() with torch.no_grad(): attention_output = self.attention(input, input_mask, head_mask, layer_past, get_present, encoder_hidden_states, encoder_attention_mask, output_attentions, self.norm_w, self.norm_b) if get_present: attention_output, p_key, p_value = attention_output[0:3] presents = (p_key, p_value) elif output_attentions: attention_output, _, _, context_output = attention_output[0:4] else: attention_output = attention_output[0] residual_add = attention_output + self.attention.attn_ob attention_output = self.ds_layernorm(residual_add, self.attn_nw, self.attn_nb, self.config.epsilon) if self.config.mlp_type == 'residual': res_mlp_out = self.res_mlp(attention_output, async_op=True) res_coef_out = self.res_coef_func(attention_output, async_op=True) if self.expert_mp_group is not None: tensor_list = [ torch.empty_like(attention_output) for _ in range(dist.get_world_size(group=self.expert_mp_group)) ] tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output dist.all_gather(tensor_list, attention_output, group=self.expert_mp_group) attention_output = torch.cat(tensor_list).contiguous() ############## MoE Gating + Experts ############### dispatched_attention, combined_weights = self.moe_gate_einsum(attention_output) dispatched_input = self._alltoall(dispatched_attention) expert_outputs = self.expert_exec(dispatched_input) expert_output = self._alltoall(expert_outputs) output = self.scale_expert_output(attention_output, expert_output, combined_weights) ################################################ if self.expert_mp_group is not None: output = output.split(output.shape[0] // dist.get_world_size(group=self.expert_mp_group), dim=0)[dist.get_rank(group=self.expert_mp_group)] if self.config.mlp_type == 'residual': inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out, output) output = self.bias_residual_func(output, residual_add, torch.empty(1)) if not self.config.pre_layer_norm: output = self.ds_layernorm(output, self.norm_w, self.norm_b, self.config.epsilon) if input_type != output.dtype: output = output.to(input_type) if get_present: output = (output, presents) if self.config.return_tuple: return output if type(output) is tuple else (output, ) else: return output