Unverified Commit b4ad815e authored by Atream's avatar Atream Committed by GitHub
Browse files

Merge pull request #891 from kvcache-ai/performance-optimize-gpu

use compile for gate, slight performance improvement
parents 6c4ed591 a889288f
from typing import Any, Union
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Optional
from torch import nn
import torch
import sys, os
import torch.nn.functional as F
import os
from ktransformers.operators.base_operator import BaseInjectedModule
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KTransformersLinear
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import time
# class Base(BaseInjectedModule, ABC):
......@@ -100,8 +89,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
prefill_device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
......@@ -131,3 +120,133 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "sigmoid",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_ids.to(torch.long), topk_weights.to(torch.float32)
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str| None = "KLinearMarlin",
use_quant: bool = False,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
self.generate_op = generate_op
self.prefill_op = prefill_op
self.is_windows = os.name == 'nt'
self.use_quant = use_quant
if not self.is_windows and use_quant:
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
gguf_loader, config, self.gate_linear, #orig_module
generate_device, generate_op, prefill_device, prefill_op)
else:
self.gate_linear = None
def forward(self, hidden_states) -> torch.Tensor:
if self.is_windows:
return self.orig_module.forward(hidden_states)
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
if self.use_quant:
logits = self.gate_linear.forward(logits)
else:
logits = F.linear(
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
)
return grouped_topk(hidden_states, logits,
self.top_k, self.norm_topk_prob,
self.n_group, self.topk_group)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
if not self.is_windows and self.use_quant:
self.gate_linear.load(self.orig_module.weight)
def unload(self):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
......@@ -477,7 +477,6 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
......
......@@ -26,7 +26,7 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
......@@ -147,7 +147,7 @@
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -157,7 +157,7 @@
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......@@ -167,7 +167,7 @@
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
......@@ -177,7 +177,7 @@
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
......
......@@ -278,7 +278,7 @@
name: "^model\\.layers\\.([0-7])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -288,7 +288,7 @@
name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......@@ -298,7 +298,7 @@
name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
......@@ -308,7 +308,7 @@
name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
......@@ -318,7 +318,7 @@
name: "^model\\.layers\\.(3[2-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:4"
prefill_device: "cuda:4"
......@@ -328,7 +328,7 @@
name: "^model\\.layers\\.(4[0-7])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:5"
prefill_device: "cuda:5"
......@@ -338,7 +338,7 @@
name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:6"
prefill_device: "cuda:6"
......@@ -348,7 +348,7 @@
name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:7"
prefill_device: "cuda:7"
......
......@@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -18,7 +18,7 @@
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
......@@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
......@@ -38,7 +38,7 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment