Unverified Commit 877aec85 authored by Yuhao Tsui, committed by GitHub

Merge branch 'kvcache-ai:main' into main

parents 84164f58 9037bf30
# coding=utf-8
# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepSeekV3 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV3Model`]
hidden_size (`int`, *optional*, defaults to 7168): hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 18432): intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 2048): moe_intermediate_size (`int`, *optional*, defaults to 1407):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 61): num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 128): num_nextn_predict_layers (`int`, *optional*, defaults to 1):
Number of nextn predict layers in the DeepSeekV3 Model.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 128): n_shared_experts (`int`, *optional*, defaults to None):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If Number of shared experts, None means dense model.
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if n_routed_experts (`int`, *optional*, defaults to None):
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When Number of routed experts, None means dense model.
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed routed_scaling_factor (`float`, *optional*, defaults to 1.0):
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
n_shared_experts (`int`, *optional*, defaults to 1):
Number of shared experts.
n_routed_experts (`int`, *optional*, defaults to 256):
Number of routed experts.
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
Scaling factor for routed experts.
kv_lora_rank (`int`, *optional*, defaults to 512): topk_method (`str`, *optional*, defaults to `gready`):
Rank of the LoRA matrices for key and value projections. Topk method used in routed gate.
q_lora_rank (`int`, *optional*, defaults to 1536): n_group (`int`, *optional*, defaults to None):
Rank of the LoRA matrices for query projections.
qk_rope_head_dim (`int`, *optional*, defaults to 64):
Dimension of the query/key heads that use rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 128):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 128):
Dimension of the query/key heads that don't use rotary position embeddings.
n_group (`int`, *optional*, defaults to 8):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to 4): topk_group (`int`, *optional*, defaults to None):
Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to 8): num_experts_per_tok (`int`, *optional*, defaults to None):
Number of selected experts, None means dense model.
first_k_dense_replace (`int`, *optional*, defaults to 3): moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to `True`): norm_topk_prob (`bool`, *optional*, defaults to False):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to 'softmax'):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
seq_aux = (`bool`, *optional*, defaults to True):
Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096): max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -98,15 +75,10 @@ class DeepseekV3Config(PretrainedConfig):
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0): bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 1): eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
@@ -120,49 +92,44 @@ class DeepseekV3Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Initializing a model from the configuration
>>> model = DeepseekV3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `DeepseekV3Model`
base_model_tp_plan = {
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
}
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size=2048,
num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts=1,
n_routed_experts=256,
ep_size=1,
routed_scaling_factor=2.5,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method='noaux_tc',
n_group=8,
topk_group=4,
num_experts_per_tok=8,
moe_layer_freq=1,
first_k_dense_replace=3,
norm_topk_prob=True,
scoring_func='sigmoid',
aux_loss_alpha=0.001,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
@@ -171,7 +138,6 @@ class DeepseekV3Config(PretrainedConfig):
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
@@ -185,24 +151,25 @@ class DeepseekV3Config(PretrainedConfig):
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim self.topk_method = topk_method
self.head_dim = qk_rope_head_dim
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.aux_loss_alpha = aux_loss_alpha self.scoring_func = scoring_func
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
@@ -211,17 +178,11 @@ class DeepseekV3Config(PretrainedConfig):
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
@@ -229,7 +190,4 @@ class DeepseekV3Config(PretrainedConfig):
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
\ No newline at end of file
__all__ = ["DeepseekV3Config"]
\ No newline at end of file
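For reference, a minimal usage sketch of the configuration class above. The import path is an assumption based on the repository layout, and the keyword values simply restate defaults already documented in the docstring.
```python
# Hedged sketch: instantiate the config with its documented MoE-related fields.
# The module path below is an assumption, not confirmed by this diff.
from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config

config = DeepseekV3Config(
    hidden_size=7168,
    moe_intermediate_size=2048,
    n_routed_experts=256,
    n_shared_experts=1,
    num_experts_per_tok=8,
    n_group=8,
    topk_group=4,
    first_k_dense_replace=3,   # first 3 layers stay dense, the rest are MoE
)
# Sanity check: the 8 selected experts must fit inside the 4 allowed groups.
assert config.num_experts_per_tok <= config.topk_group * (config.n_routed_experts // config.n_group)
```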
@@ -8,9 +8,11 @@ Version : 0.1.0
# Copyright 2018- The Hugging Face team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
@@ -188,3 +190,85 @@ class StaticCache(transformers.StaticCache):
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""
return self.max_cache_len
class KDeepSeekV3Cache(nn.Module):
def __init__(
self,
config: PretrainedConfig,
page_size: int = 256,
dtype=torch.bfloat16,
device=torch.device("cuda:0"),
):
super().__init__()
self.config = config
self.dtype = dtype
self.device = device
self.kv_lora_rank = config.kv_lora_rank
self.page_size = page_size
self.k_caches = []
self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
inference_context.k_cache[0][i]
)
self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""
k_out = self.k_caches[layer_idx]
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])
return k_out
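update() packs the compressed KV (first kv_lora_rank channels) and the rotary key part (remaining channels) into one paged buffer. A shape-only sketch of that layout with toy sizes, skipping the batch reshape done in the real method:
```python
# Illustrative only: each page row stores [compressed_kv | k_pe] along the last dim.
import torch

num_pages, page_size, kv_lora_rank, rope_dim = 4, 8, 16, 4
k_cache = torch.zeros(num_pages, page_size, 1, kv_lora_rank + rope_dim)

key_states = torch.randn(3, 1, kv_lora_rank)   # 3 new tokens of compressed KV
value_states = torch.randn(3, 1, rope_dim)     # their rotary-key part
page_idx = torch.tensor([0, 0, 1])             # physical page of each token
page_offset = torch.tensor([6, 7, 0])          # slot within that page

k_cache[page_idx, page_offset, :, :kv_lora_rank] = key_states
k_cache[page_idx, page_offset, :, kv_lora_rank:] = value_states
```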
def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
page_offset = cache_position % self.page_size
page_idx_local = cache_position // self.page_size
query_ids = torch.zeros_like(cache_position)
for i in range(len(q_indptr) - 1):
start_idx = q_indptr[i]
end_idx = q_indptr[i + 1]
query_ids[start_idx:end_idx] = i
page_idx = torch.zeros_like(page_idx_local)
for i in range(bsz_tensors[0]):
query_id = query_ids[i]
local_block = page_idx_local[i]
start_block = kv_indptr[query_id]
if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
page_idx[i] = kv_indices[start_block + local_block]
return page_idx, page_offset
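The loops above translate a logical cache position into a physical page and an offset inside that page. A tiny self-contained walk-through of the same arithmetic with invented numbers:
```python
# Illustrative only: page_size=4, one request whose physical pages are [7, 2, 9].
# Token position 5 lands in local page 1, offset 1, which maps to physical page 2.
import torch

page_size = 4
cache_position = torch.tensor([5])       # absolute token position in the sequence
kv_indices = torch.tensor([7, 2, 9])     # physical pages owned by this request
kv_indptr = torch.tensor([0, 3])         # single request -> pages kv_indices[0:3]

page_offset = cache_position % page_size                 # tensor([1])
local_page = cache_position // page_size                 # tensor([1])
physical_page = kv_indices[kv_indptr[0] + local_page]    # tensor([2])
print(physical_page.item(), page_offset.item())          # 2 1
```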
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from ktransformers.models.modeling_deepseek import DeepseekV2Model, DeepseekV2PreTrainedModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer
class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
kv_cache: KDeepSeekV3Cache
use_cuda_graph = False
def __init__(
self,
config,
kv_cache,
):
super().__init__(config)
self.model = DeepseekV2Model(config)
self.config = config
self.kv_cache = kv_cache
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
self.use_cuda_graph = use_cuda_graph
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf
)
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
features = []
for i in range(batch.batch_size):
tokens = batch.minibatch.tokens.contiguous()
feature = (
self.model.embed_tokens(tokens.to(torch.device('cpu')))
.to(torch.bfloat16)
.to(device=device)
)
features.append(feature)
return features
def forward(
self,
batch: ForwardBatchInput | None = None,
features: List[torch.Tensor] | None = None,
bsz_tensors: torch.Tensor | None = None,
num_tokens_tensors: torch.Tensor | None = None,
page_idx: torch.Tensor | None = None,
page_offset: torch.Tensor | None = None,
) -> ForwardBatchOutput:
current_stream = torch.cuda.current_stream()
forward_batch_output = ForwardBatchOutput()
hidden_states = features[0]
with torch.cuda.stream(current_stream):
residual = torch.zeros_like(hidden_states)
for i, decode_layer in enumerate(self.model.layers):
if self.model.transfer_map is not None and i in self.model.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.model.transfer_map[i]
if cur_device not in self.model.stream_device_map:
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
hidden_states = hidden_states.to(
self.model.transfer_map[i], non_blocking=True
)
batch.minibatch.position_ids = (
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
if batch.minibatch.position_ids is not None
else None
)
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
hidden_states = decode_layer.self_attn(hidden_states, self.kv_cache,
position_ids=batch.minibatch.position_ids,
wrapper=self.wrapper, bsz_tensors=num_tokens_tensors,
cache_position=batch.minibatch.positions,
batch_indices=batch.minibatch.batch_indices,
kv_indices=batch.minibatch.kv_indices,
kv_indptr=batch.minibatch.kv_indptr,
kv_last_page_len=batch.minibatch.kv_last_page_len,
q_indptr=batch.minibatch.q_indptr,
page_idx=page_idx,
page_offset=page_offset
)
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
if i < 3:
hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
else:
hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors)
hidden_states = hidden_states.squeeze(0)
forward_batch_output = ForwardBatchOutput()
assert batch.batch_size == 1
with torch.cuda.stream(current_stream):
local_logit = self.lm_head(self.model.norm(hidden_states[batch.minibatch.logits_start], num_tokens_tensors, residual[batch.minibatch.logits_start])[0])
# local_logit = local_logit[batch.minibatch.logits_start]
forward_batch_output.logits.append(local_logit)
return forward_batch_output
def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,):
minibatch = batch.minibatch
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type)
\ No newline at end of file
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model, DeepseekV3PreTrainedModel
from ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer
class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
cache: KDeepSeekV3Cache
use_cuda_graph = False
def __init__(
self,
config: DeepseekV3Config,
cache,
):
super().__init__(config)
self.model = DeepseekV3Model(config)
self.config = config
self.cache = cache
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
self.use_cuda_graph = use_cuda_graph
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
bsz_tensor=self.bsz_tensor_buf
)
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
features = []
for i in range(batch.batch_size):
tokens = batch.minibatch.tokens.contiguous()
feature = (
self.model.embed_tokens(tokens.to(torch.device('cpu')))
.to(torch.bfloat16)
.to(device=device)
)
features.append(feature)
return features
def forward(
self,
batch: ForwardBatchInput | None = None,
features: List[torch.Tensor] | None = None,
bsz_tensors: torch.Tensor | None = None,
num_tokens_tensors: torch.Tensor | None = None,
page_idx: torch.Tensor | None = None,
page_offset: torch.Tensor | None = None,
cuda_graph_idx: int | None = -1
) -> ForwardBatchOutput:
current_stream = torch.cuda.current_stream()
forward_batch_output = ForwardBatchOutput()
hidden_states = features[0]
with torch.cuda.stream(current_stream):
residual = torch.zeros_like(hidden_states)
for i, decode_layer in enumerate(self.model.layers):
# can't use now, only one flashinfer wrapper
if self.model.transfer_map is not None and i in self.model.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.model.transfer_map[i]
if cur_device not in self.model.stream_device_map:
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
hidden_states = hidden_states.to(
self.model.transfer_map[i], non_blocking=True
)
batch.minibatch.position_ids = (
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
if batch.minibatch.position_ids is not None
else None
)
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
position_ids=batch.minibatch.position_ids,
wrapper=self.wrapper, num_tokens_tensors=num_tokens_tensors,
page_idx=page_idx,
page_offset=page_offset
)
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
if i < self.config.first_k_dense_replace:
hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)
else:
hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
hidden_states = hidden_states.squeeze(0)
forward_batch_output = ForwardBatchOutput()
with torch.cuda.stream(current_stream):
local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
forward_batch_output.logits.append(local_logit)
return forward_batch_output
def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
num_heads: int,
head_dim_ckv: int,
head_dim_kpe: int,
page_size: int,
causal: bool,
sm_scale: float,
q_data_type: torch.dtype,
kv_data_type: torch.dtype,):
minibatch = batch.minibatch
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
\ No newline at end of file
@@ -99,6 +99,7 @@ class DeepseekV3RMSNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.hidden_size = hidden_size
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
@@ -398,7 +399,6 @@ class MoEGate(nn.Module):
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.scoring_func = config.scoring_func
self.seq_aux = config.seq_aux
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
@@ -436,6 +436,7 @@ class MoEGate(nn.Module):
### select top-k experts
if self.topk_method == "noaux_tc":
#assert not self.training
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
@@ -454,7 +455,7 @@ class MoEGate(nn.Module):
)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
_, topk_idx = torch.topk(
tmp_scores, k=self.top_k, dim=-1, sorted=False
)
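A hedged toy example of the group-limited routing shown above: experts are split into groups, groups are ranked by the sum of their two best scores, only the best `topk_group` groups survive, and the final experts are chosen from the survivors. The numbers are invented, and the `e_score_correction_bias` and weight rescaling of the real gate are omitted.
```python
import torch

# 8 experts, n_group=4 groups of 2, keep topk_group=2 groups, pick top_k=2 experts.
scores = torch.tensor([[0.9, 0.1, 0.2, 0.75, 0.05, 0.3, 0.7, 0.6]])  # [n_tokens, n_experts]
n_group, topk_group, top_k = 4, 2, 2

group_scores = scores.view(1, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)   # [[1.0, 0.95, 0.35, 1.3]]
group_idx = group_scores.topk(topk_group, dim=-1)[1]                        # groups 3 and 0
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(1, n_group, scores.size(-1) // n_group).reshape(1, -1)
masked = scores.masked_fill(~score_mask.bool(), float("-inf"))   # masked experts can never win
topk_weight, topk_idx = torch.topk(masked, k=top_k, dim=-1)      # experts 0 and 6
```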
@@ -1933,4 +1934,4 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
\ No newline at end of file
@@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(
self.orig_module.rope_type,
self.orig_module.config,
)
class RotaryEmbeddingV4(BaseInjectedModule):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def load(self):
self._init(
dim=self.config.qk_rope_head_dim,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
device=self.device,
)
def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
\ No newline at end of file
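A small numeric sketch of the rotary-embedding math used by RotaryEmbeddingV4: inv_freq[i] = 1 / base^(2i/dim), and cos/sin are built from position * inv_freq. The dim and base below are toy values, not the model's real qk_rope_head_dim or rope_theta.
```python
import torch

dim, base = 4, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # [1.0, 0.01]
position_ids = torch.arange(3).float()[None, :]               # [[0., 1., 2.]]

freqs = position_ids[:, :, None] * inv_freq[None, None, :]    # [1, 3, 2], same result as the matmul above
emb = torch.cat((freqs, freqs), dim=-1)                        # duplicate to cover the full head dim
cos, sin = emb.cos(), emb.sin()                                # each [1, 3, 4]
```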
@@ -32,7 +32,8 @@ import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
from flashinfer.mla import BatchMLAPagedAttentionWrapper
from ktransformers.models.custom_cache import KDeepSeekV3Cache
logger = logging.getLogger("attention")
# Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -421,6 +422,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if q_len == 1:
self.mla_wrapper.plan(None,None,None,
position_ids.squeeze(1)+1,
None,
self.num_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
@@ -433,6 +435,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
self.mla_wrapper.plan(qo_indptr,None,None,
kv_len_arr,
None,
self.num_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
@@ -759,3 +762,92 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
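get_absorbed splits kv_b_proj into the two halves that can be folded into the query path and the output path (the MLA "matrix absorption" trick), so full keys and values never need to be materialised. A shape-level sketch with toy sizes, not the model's real dimensions:
```python
import torch

num_heads, kv_lora_rank, qk_nope_head_dim, v_head_dim, q_len = 2, 8, 4, 4, 3
kv_b_proj = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w = kv_b_proj.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
q_absorb = w[:, :qk_nope_head_dim, :]     # [heads, nope_dim, lora_rank]
out_absorb = w[:, qk_nope_head_dim:, :]   # [heads, v_dim, lora_rank]

q_nope = torch.randn(q_len, num_heads, qk_nope_head_dim)
# Project the query into the compressed-KV (latent) space, as in forward():
q_latent = torch.matmul(q_nope.transpose(0, 1), q_absorb).transpose(0, 1)  # [q_len, heads, lora_rank]
assert q_latent.shape == (q_len, num_heads, kv_lora_rank)
```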
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
wrapper: BatchMLAPagedAttentionWrapper,
num_tokens_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states, num_tokens_tensors)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
q = q.view(q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.contiguous()
compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
q_pe = q_pe.squeeze(0)
if kv_cache is not None:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(0, 1)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
attn_output = attn_output.transpose(0, 1)
attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, num_tokens_tensors)
return attn_output
@@ -37,6 +37,10 @@ import time
from ktransformers.operators.cpuinfer import CPUInfer
def deduplicate_and_sort(lst):
return sorted(set(lst))
#cuda_graphs = [Config().chunk_size]
cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
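deduplicate_and_sort simply collapses the candidate CUDA-graph batch sizes into a sorted list of unique values; a tiny illustration with made-up sizes standing in for the Config entries:
```python
# Purely illustrative: 64 and 256 stand in for Config().max_batch_size / chunk_size.
assert deduplicate_and_sort([1, 2, 3, 64, 64, 256]) == [1, 2, 3, 64, 256]
```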
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
@@ -112,6 +116,7 @@ class KExpertsBase(ABC):
tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
return tensors
class KExpertsCPU(KExpertsBase):
input_tensor_cpu:Tensor = None
expert_ids_cpu:Tensor = None
@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):
output_cpu:Tensor = None
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
#gguf_loader:GGUFLoader = None # @TODO add yaml
CPU_INFER = None CPU_INFER = CPUInfer(Config().cpu_infer)
def __init__(
self,
key: str,
@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):
**kwargs
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
if KExpertsCPU.CPU_INFER is None:
KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
#if KExpertsCPU.gguf_loader is None:
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
self.gguf_loader = gguf_loader
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
self.n_routed_experts = n_routed_experts
self.out_device = out_device
@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
#print(self.gate_type, self.up_type, self.down_type) # print(self.gate_qtype, self.up_qtype, self.down_qtype)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.submit(self.moe.warm_up())
self.cpu_infer.sync()
if self.out_device not in KExpertsCPU.output_gpu_map:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device) if isinstance(cuda_graphs, list):
KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
if KExpertsCPU.input_tensor_cpu == None: if KExpertsCPU.input_tensor_cpu == None:
KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True) if isinstance(cuda_graphs, list):
KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights): def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) if bsz_tensor is None:
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) if cuda_graph_idx != -1:
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr())) KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
def sync_for_one_decode(self): KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream) KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
return KExpertsCPU.output_gpu_map[self.out_device] else:
def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
#print("capturing experts")
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True) KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True) KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True) KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr())) KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream) self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
def sync_for_one_decode(self, cuda_graph_idx=0):
if cuda_graph_idx != -1:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True) KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device] return KExpertsCPU.output_gpu_map[self.out_device]
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
if bsz_tensor is None:
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
if torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
else: else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
weights = weights.contiguous().to(torch.float32).cpu()
bsz_tensor = bsz_tensor.contiguous().cpu()
output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr())) self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
self.cpu_infer.sync()
return output.to(device=object.__getattribute__(self, "out_device"))
@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y += y_
return y
@torch.no_grad()
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
return final_hidden_states
\ No newline at end of file
class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
            # TODO: may have bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
            # TODO: may have bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
if self.config.n_shared_experts is not None:
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
    # TODO: may have bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
    # TODO: may have bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
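# --- Illustrative sketch (not part of the original diff): the dispatch idea behind
# moe_infer above, shown on toy tensors with identity "experts". Token copies are sorted
# by expert id so each expert sees one contiguous slice, results are scattered back to
# the original order, then combined with the routing weights. Names here are hypothetical.
def _moe_dispatch_sketch():
    num_tokens, hidden, topk = 3, 8, 2
    x = torch.randn(num_tokens, hidden)
    topk_ids = torch.randint(0, 4, (num_tokens, topk))
    topk_weight = torch.rand(num_tokens, topk)
    idxs = topk_ids.view(-1).argsort()            # flat (token, expert) pairs grouped by expert id
    sorted_tokens = x[idxs // topk]               # one token copy per selected expert
    outs = sorted_tokens.clone()                  # identity experts stand in for expert.forward
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs                            # undo the sort
    return (new_x.view(num_tokens, topk, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)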
class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_op: str | None = "KExpertsTorch",
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
self.generate_experts = None
if prefill_op is not None:
self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
else:
self.prefill_experts = None
self.gpu_mlp_type = prefill_op
self.cpu_mlp_type = generate_op
self.mode = InferenceState.UNLOAD
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
# TODO support w as input
if not mode: mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
self.prefill_experts.unload()
self.generate_experts.load(w, warmup=warmup)
self.device = self.generate_experts.device
self.mode = mode
elif mode == InferenceState.PREFILL:
self.generate_experts.unload()
self.prefill_experts.load(w, warmup=warmup)
self.device = self.prefill_experts.device
self.mode = mode
elif mode == InferenceState.UNLOAD:
self.unload()
self.mode = mode
self.device = self.generate_experts.device
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
def unload(self):
if self.generate_experts is not None:
self.generate_experts.unload()
if self.prefill_experts is not None:
self.prefill_experts.unload()
self.device = self.generate_experts.device
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
else:
raise ValueError("load or set_inference_mode before forward")
def set_inference_mode(self, mode: InferenceState):
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE, warmup=False)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL, warmup=False)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
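# --- Illustrative usage sketch (not part of the original diff), assuming `experts` is an
# already-injected KTransformersExpertsV2 instance and the routing tensors come from the
# gate; the variable names below are placeholders, not real module paths:
#
#   experts.load(mode=InferenceState.GENERATE)            # activate the CPU decode experts
#   out = experts.forward(hidden_states,                  # [num_tokens, hidden_size]
#                         topk_idx, topk_weight,          # [num_tokens, num_selected_experts]
#                         bsz_tensor,                     # int32 tensor holding the batch size
#                         cuda_graph_idx=0)
#   experts.set_inference_mode(InferenceState.UNLOAD)     # release weights when finished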
@@ -86,6 +86,7 @@ class MLAWrapper():
            self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
            self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
            self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
            self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        else:
            self.qo_indptr_buf = None
@@ -94,19 +95,22 @@ class MLAWrapper():
            self.kv_len_arr_buf = None
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=use_cuda_graph,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
            bsz_tensor=self.batch_size_tensor_buf
        )
        self.need_plan = True
    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
@@ -124,6 +128,9 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf
        if bsz_tensor is None:
            assert self.max_batch_size == 1
            bsz_tensor = self.batch_size_tensor_buf

        self.wrapper.plan(
            qo_indptr,
@@ -138,6 +145,7 @@ class MLAWrapper():
            sm_scale,
            q_data_type,
            kv_data_type,
            bsz_tensor
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
@@ -161,6 +169,7 @@ class MLAWrapperSingleton():
             kv_indptr,
             kv_indices,
             kv_len_arr,
             bsz_tensor,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
@@ -174,6 +183,7 @@ class MLAWrapperSingleton():
                kv_indptr,
                kv_indices,
                kv_len_arr_cur_device,
                bsz_tensor,
                num_heads,
                head_dim_ckv,
                head_dim_kpe,
@@ -240,16 +250,17 @@ if __name__ == "__main__":
    #checksame()
    #exit(0)

    max_batch_size = 2
    max_batch_tokens = 256
    max_pages = 128
    page_size = 64
    num_heads = 128

    # warm-up
    kv_len = 4023
    q_len = 1
    q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
@@ -260,13 +271,19 @@ if __name__ == "__main__":
        max_pages,
    )

    used_pages = (kv_len + page_size - 1)// page_size
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
    kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
    kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
    bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")

    wrapper.plan(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_len_arr,
        bsz_tensor,
        128,
        512,
        64,
@@ -276,14 +293,98 @@ if __name__ == "__main__":
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
    print(attn_output.shape)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
graph.replay()
q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
1,
q[:q_len],
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
# warm-up finished # warm-up finished
kv_len = 512
q_len = 128
pages = max_pages
used_pages = (kv_len + page_size - 1)// page_size
q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_nope[q_len:] = q_nope[:q_len]
q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_pe[q_len:] = q_pe[:q_len]
kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda")
bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
q_nope_buf.copy_(q_nope)
q_pe_buf.copy_(q_pe)
kv_buf[:pages].copy_(kv_cache)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
# ref_torch
q = torch.cat([q_nope, q_pe], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
max_batch_size,
q,
k[:2*kv_len],
v[:2*kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)
#torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
#torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)
exit(0)
    for forward_id in range(0, 1):
        print("forward_id", forward_id)
        for layer_id in range(1):
@@ -376,5 +477,4 @@ if __name__ == "__main__":
            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #ktrans_output = torch.load(file_name)
            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
    print("test past")
\ No newline at end of file
@@ -122,132 +122,3 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.e_score_correction_bias = None
# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
routed_scaling_factor: float = 1.0,
scoring_func: str = "sigmoid",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
#float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if topk > 1 and renormalize:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights = topk_weights / denominator
topk_weights = topk_weights * routed_scaling_factor # must multiply the scaling factor
return topk_ids.to(torch.long), topk_weights.to(torch.float32)
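# --- Illustrative sketch (not part of the original file): exercising grouped_topk above on
# toy tensors. 8 experts are split into 4 groups of 2; only the top-2 groups per token stay
# eligible, then the top-2 experts are chosen from those groups and their weights are
# renormalized and scaled by routed_scaling_factor.
def _grouped_topk_sketch():
    toy_hidden = torch.randn(2, 16)               # only the token count is checked inside
    toy_logits = torch.randn(2, 8)                # gating output for 8 routed experts
    ids, w = grouped_topk(toy_hidden, toy_logits, topk=2, renormalize=True,
                          num_expert_group=4, topk_group=2,
                          routed_scaling_factor=2.5, scoring_func="sigmoid")
    return ids, w                                 # ids: [2, 2] int64; w: [2, 2] float32, rows sum to ~2.5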
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str| None = "KLinearMarlin",
use_quant: bool = False,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
self.generate_op = generate_op
self.prefill_op = prefill_op
self.is_windows = os.name == 'nt'
self.use_quant = use_quant
if not self.is_windows and use_quant:
print("injecting gate_linear")
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
gguf_loader, config, self.gate_linear, #orig_module
generate_device, generate_op, prefill_device, prefill_op)
else:
self.gate_linear = None
def forward(self, hidden_states) -> torch.Tensor:
if True or self.is_windows:
return self.orig_module.forward(hidden_states)
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
if self.use_quant:
logits = self.gate_linear.forward(hidden_states)
else:
logits = F.linear(
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
)
return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob, self.n_group,
self.topk_group, self.routed_scaling_factor, "sigmoid", self.e_score_correction_bias)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
if not self.is_windows and self.use_quant:
self.gate_linear.load(self.orig_module.weight)
def unload(self):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
'''
Date: 2024-11-13 15:05:52
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:19
'''
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
fused_add_rmsnorm,
rmsnorm,
)
logger = logging.getLogger(__name__)
class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x: torch.Tensor,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
#return self.forward_native(x, residual)
if batch_size_tensor is None:
return self.forward_native(x)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return x, residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
return out
def forward_native(
self, hidden_states
):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
\ No newline at end of file
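# --- Illustrative sketch (not part of the original file): the math that RMSNorm.forward_native
# above implements, written out once with plain torch for reference:
# y = weight * x / sqrt(mean(x**2, dim=-1) + eps), normalized in float32 and cast back.
# It mirrors forward_native for a module with unit weight and variance_epsilon=1e-6.
def _rmsnorm_reference_sketch():
    x = torch.randn(2, 4, dtype=torch.bfloat16)
    weight = torch.ones(4, dtype=torch.bfloat16)
    eps = 1e-6
    xf = x.to(torch.float32)
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return weight * normed.to(x.dtype)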
@@ -15,14 +15,16 @@ import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
    MarlinWorkspace,
    marlin_quantize,
    GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_MIN_THREAD_K,
    GPTQ_MARLIN_MAX_PARALLEL,
    vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
@@ -84,8 +86,10 @@ class KLinearBase(ABC):
        if self.gguf_loader.safetensor_loader is not None:
            # using safetensor_loader
            tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
            if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
            return nn.Parameter(tensor)

        elif key + ".weight" in self.gguf_loader.tensor_file_map:
            if key + ".bias" in self.gguf_loader.tensor_file_map:
@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
        self.weight = None
        self.has_bias = False

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        dtype = x.dtype
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None

class KLinearQ8(KLinearBase):
    def __init__(
        self,
@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
        self.dtype = torch.get_default_dtype()
        self.block_size = block_size

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
        x = x.to(self.device)
        orig_dtype = x.dtype
        x_quantized, scale_x = act_quant(x, self.block_size)
@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
        self.weight = None
        if self.has_bias:
            self.bias = None
# TODO: merge the two Marlin classes
class VLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
num_bits: int = 4, # 4-bit/8-bit is supported
group_size: int = 64, # -1, 32, 64, 128
act_order: bool = False,
is_k_full=True,
**kwargs,
):
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.num_bits = num_bits
self.group_size = group_size
self.act_order = act_order
self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
# pad weight
weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False
elif isinstance(w, tuple):
w = list(w)
weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.bias = w[1]
self.has_bias = True
else:
raise ValueError("Invalid weight type")
weight = weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear
marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
self.sort_indices = sort_indices
self.k = weight.shape[0]
self.n = weight.shape[1]
# self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
self.loaded = True
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
if bsz_tensor is None:
bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
# Only support input x as BF16 and FP16
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
marlin_s,
self.g_idx,
self.sort_indices,
self.workspace.scratch,
self.num_bits,
bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
return x.reshape(orig_shape).to(orig_dtype)
def unload(self):
if self.has_bias:
self.bias = None
self.marlin_q_w = None
self.marlin_s = None
self.g_idx = None
self.sort_indices = None
self.workspace = None
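# --- Illustrative sketch (not part of the original diff): the rounding-up used by the
# padding path above, with a hypothetical granularity of 128 standing in for
# GPTQ_MARLIN_MIN_THREAD_K / GPTQ_MARLIN_MIN_THREAD_N (whose real values come from
# marlin_utils). A 1000-wide dimension is padded to the next multiple of the granularity.
def _marlin_padding_sketch():
    granularity = 128                              # assumed value, for illustration only
    in_features = 1000
    padded = (in_features + granularity - 1) // granularity * granularity
    assert padded == 1024                          # rounded up to the next multiple
    return padded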
class KLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor
@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
        self.n = weight.shape[1]
        self.loaded = True

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
        # Only support input x as BF16 and FP16
        x = x.to(self.device)
        orig_shape = list(x.shape)
@@ -552,7 +699,7 @@ class KLinearCPUInfer(KLinearBase):
        self.group_max_len = group_max_len
        self.out_device = out_device

    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
        origin_shape = x.shape # [batch_size, q_len, hidden_size]
        if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
            out_device = x.device
@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
        if self.w is not None:
            self.w = None
        if self.has_bias:
            self.bias = None

LINEAR_MAP = {
    "KLinearMarlin": KLinearMarlin,
    "KLinearTorch": KLinearTorch,
    "KLinearCPUInfer": KLinearCPUInfer,
    "VLinearMarlin": VLinearMarlin,
    "KLinearFP8": KLinearFP8,
    "KLinearQ8": KLinearQ8,
}
@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
        self.generate_linear = None
        self.mode = InferenceState.UNLOAD

    def forward(self, x, bsz_tensor=None):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
            y = self.prefill_linear.forward(x, bsz_tensor)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
            y = self.generate_linear.forward(x, bsz_tensor)
        return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.hidden_size, orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj
\ No newline at end of file
@@ -22,7 +22,7 @@
  replace:
    class: ktransformers.operators.linear.KTransformersLinear
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
...
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
\ No newline at end of file
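# --- Illustrative sketch (not part of the original rules file): how the negative lookahead
# in the per-layer Linear match above behaves, checked with plain `re`. kv_b_proj stays a
# stock torch.nn.Linear; every other matching per-layer Linear is replaced by KTransformersLinear.
import re
_pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")
assert _pattern.match("model.layers.0.mlp.gate_proj") is not None        # replaced
assert _pattern.match("model.layers.0.self_attn.q_a_proj") is not None   # replaced
assert _pattern.match("model.layers.0.self_attn.kv_b_proj") is None      # left untouched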
@@ -26,7 +26,7 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
...
@@ -147,7 +147,7 @@
    name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
@@ -157,7 +157,7 @@
    name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
@@ -167,7 +167,7 @@
    name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
@@ -177,7 +177,7 @@
    name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
...
@@ -278,7 +278,7 @@
    name: "^model\\.layers\\.([0-7])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
@@ -288,7 +288,7 @@
    name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
@@ -298,7 +298,7 @@
    name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
@@ -308,7 +308,7 @@
    name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
@@ -318,7 +318,7 @@
    name: "^model\\.layers\\.(3[2-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
@@ -328,7 +328,7 @@
    name: "^model\\.layers\\.(4[0-7])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
@@ -338,7 +338,7 @@
    name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
@@ -348,7 +348,7 @@
    name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"
...
@@ -66,7 +66,7 @@
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
@@ -74,7 +74,7 @@
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
...
@@ -66,7 +66,7 @@
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
@@ -74,7 +74,7 @@
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
...