Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
3f9bbf11
Commit
3f9bbf11
authored
Apr 28, 2025
by
djw
Browse files
support qwen3, dont speak human language
parent
f3d842a0
Changes
30
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3470 additions
and
106 deletions
+3470
-106
ktransformers/models/configuration_qwen2_moe.py
ktransformers/models/configuration_qwen2_moe.py
+177
-0
ktransformers/models/configuration_qwen3_moe.py
ktransformers/models/configuration_qwen3_moe.py
+233
-0
ktransformers/models/custom_cache.py
ktransformers/models/custom_cache.py
+56
-0
ktransformers/models/custom_modeling_qwen2_moe.py
ktransformers/models/custom_modeling_qwen2_moe.py
+133
-0
ktransformers/models/custom_modeling_qwen3_moe.py
ktransformers/models/custom_modeling_qwen3_moe.py
+133
-0
ktransformers/models/modeling_qwen3_moe.py
ktransformers/models/modeling_qwen3_moe.py
+1472
-0
ktransformers/operators/RoPE.py
ktransformers/operators/RoPE.py
+27
-1
ktransformers/operators/attention.py
ktransformers/operators/attention.py
+0
-89
ktransformers/operators/balance_serve_attention.py
ktransformers/operators/balance_serve_attention.py
+287
-0
ktransformers/operators/experts.py
ktransformers/operators/experts.py
+227
-0
ktransformers/operators/flashinfer_batch_prefill_wrapper.py
ktransformers/operators/flashinfer_batch_prefill_wrapper.py
+324
-0
ktransformers/operators/gate.py
ktransformers/operators/gate.py
+69
-0
ktransformers/operators/layernorm.py
ktransformers/operators/layernorm.py
+88
-1
ktransformers/operators/mlp.py
ktransformers/operators/mlp.py
+16
-2
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml
...rmers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml
+1
-1
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml
...mers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml
+1
-1
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
+95
-0
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
+95
-0
ktransformers/server/args.py
ktransformers/server/args.py
+3
-1
ktransformers/server/backend/interfaces/balance_serve.py
ktransformers/server/backend/interfaces/balance_serve.py
+33
-10
No files found.
ktransformers/models/configuration_qwen2_moe.py
0 → 100644
View file @
3f9bbf11
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2MoE model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
Qwen2MoeConfig
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen1.5-MoE-A2.7B" [Qwen/Qwen1.5-MoE-A2.7B"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B").
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert.
shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
Intermediate size of the shared expert.
num_experts_per_tok (`int`, *optional*, defaults to 4):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 60):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
```python
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
>>> # Initializing a Qwen2MoE style configuration
>>> configuration = Qwen2MoeConfig()
>>> # Initializing a model from the Qwen1.5-MoE-A2.7B" style configuration
>>> model = Qwen2MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type
=
"qwen2_moe"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
151936
,
hidden_size
=
2048
,
intermediate_size
=
5632
,
num_hidden_layers
=
24
,
num_attention_heads
=
16
,
num_key_value_heads
=
16
,
hidden_act
=
"silu"
,
max_position_embeddings
=
32768
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
use_sliding_window
=
False
,
sliding_window
=
4096
,
max_window_layers
=
28
,
attention_dropout
=
0.0
,
decoder_sparse_step
=
1
,
moe_intermediate_size
=
1408
,
shared_expert_intermediate_size
=
5632
,
num_experts_per_tok
=
4
,
num_experts
=
60
,
norm_topk_prob
=
False
,
output_router_logits
=
False
,
router_aux_loss_coef
=
0.001
,
mlp_only_layers
=
None
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
use_sliding_window
=
use_sliding_window
self
.
sliding_window
=
sliding_window
if
use_sliding_window
else
None
self
.
max_window_layers
=
max_window_layers
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
attention_dropout
=
attention_dropout
# MoE arguments
self
.
decoder_sparse_step
=
decoder_sparse_step
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
shared_expert_intermediate_size
=
shared_expert_intermediate_size
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
num_experts
=
num_experts
self
.
norm_topk_prob
=
norm_topk_prob
self
.
output_router_logits
=
output_router_logits
self
.
router_aux_loss_coef
=
router_aux_loss_coef
self
.
mlp_only_layers
=
[]
if
mlp_only_layers
is
None
else
mlp_only_layers
super
().
__init__
(
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
\ No newline at end of file
ktransformers/models/configuration_qwen3_moe.py
0 → 100644
View file @
3f9bbf11
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3MoE model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_rope_utils
import
rope_config_validation
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
Qwen3MoeConfig
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen3MoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6144):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
decoder_sparse_step (`int`, *optional*, defaults to 1):
The frequency of the MoE layer.
moe_intermediate_size (`int`, *optional*, defaults to 768):
Intermediate size of the routed expert.
num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of selected experts.
num_experts (`int`, *optional*, defaults to 128):
Number of routed experts.
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
```python
>>> from transformers import Qwen3MoeModel, Qwen3MoeConfig
>>> # Initializing a Qwen3MoE style configuration
>>> configuration = Qwen3MoeConfig()
>>> # Initializing a model from the Qwen3-15B-A2B" style configuration
>>> model = Qwen3MoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type
=
"qwen3_moe"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
# Default tensor parallel plan for base model `Qwen3Moe`
base_model_tp_plan
=
{
"layers.*.self_attn.q_proj"
:
"colwise"
,
"layers.*.self_attn.k_proj"
:
"colwise"
,
"layers.*.self_attn.v_proj"
:
"colwise"
,
"layers.*.self_attn.o_proj"
:
"rowwise"
,
"layers.*.mlp.gate_proj"
:
"colwise"
,
"layers.*.mlp.up_proj"
:
"colwise"
,
"layers.*.mlp.down_proj"
:
"rowwise"
,
}
base_model_pp_plan
=
{
"embed_tokens"
:
([
"input_ids"
],
[
"inputs_embeds"
]),
"layers"
:
([
"hidden_states"
,
"attention_mask"
],
[
"hidden_states"
]),
"norm"
:
([
"hidden_states"
],
[
"hidden_states"
]),
}
def
__init__
(
self
,
vocab_size
=
151936
,
hidden_size
=
2048
,
intermediate_size
=
6144
,
num_hidden_layers
=
24
,
num_attention_heads
=
32
,
num_key_value_heads
=
4
,
hidden_act
=
"silu"
,
max_position_embeddings
=
32768
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
attention_bias
=
False
,
use_sliding_window
=
False
,
sliding_window
=
4096
,
max_window_layers
=
28
,
attention_dropout
=
0.0
,
decoder_sparse_step
=
1
,
moe_intermediate_size
=
768
,
num_experts_per_tok
=
8
,
num_experts
=
128
,
norm_topk_prob
=
False
,
output_router_logits
=
False
,
router_aux_loss_coef
=
0.001
,
mlp_only_layers
=
None
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
use_sliding_window
=
use_sliding_window
self
.
sliding_window
=
sliding_window
if
use_sliding_window
else
None
self
.
max_window_layers
=
max_window_layers
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if
self
.
rope_scaling
is
not
None
and
"type"
in
self
.
rope_scaling
:
self
.
rope_scaling
[
"rope_type"
]
=
self
.
rope_scaling
[
"type"
]
rope_config_validation
(
self
)
# MoE arguments
self
.
decoder_sparse_step
=
decoder_sparse_step
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
num_experts
=
num_experts
self
.
norm_topk_prob
=
norm_topk_prob
self
.
output_router_logits
=
output_router_logits
self
.
router_aux_loss_coef
=
router_aux_loss_coef
self
.
mlp_only_layers
=
[]
if
mlp_only_layers
is
None
else
mlp_only_layers
super
().
__init__
(
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
__all__
=
[
"Qwen3MoeConfig"
]
\ No newline at end of file
ktransformers/models/custom_cache.py
View file @
3f9bbf11
...
@@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):
...
@@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):
return
page_idx
,
page_offset
return
page_idx
,
page_offset
class
KGQACache
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
page_size
:
int
=
256
,
dtype
=
torch
.
bfloat16
,
device
=
torch
.
device
(
"cuda:0"
),
):
super
().
__init__
()
self
.
config
=
config
self
.
dtype
=
dtype
self
.
device
=
device
self
.
page_size
=
page_size
self
.
k_caches
=
[]
self
.
v_caches
=
[]
def
load
(
self
,
inference_context
:
sched_ext
.
InferenceContext
):
print
(
self
.
config
.
num_hidden_layers
)
for
i
in
range
(
self
.
config
.
num_hidden_layers
):
self
.
k_caches
.
append
(
inference_context
.
k_cache
[
0
][
i
]
)
self
.
v_caches
.
append
(
inference_context
.
v_cache
[
0
][
i
]
)
self
.
max_cache_len
=
self
.
k_caches
[
0
].
shape
[
0
]
*
self
.
k_caches
[
0
].
shape
[
1
]
def
get_page_table
(
self
,
cache_position
:
torch
.
Tensor
,
q_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
kv_indices
:
torch
.
Tensor
,
bsz_tensors
:
torch
.
tensor
):
page_offset
=
cache_position
%
self
.
page_size
page_idx_local
=
cache_position
//
self
.
page_size
query_ids
=
torch
.
zeros_like
(
cache_position
)
for
i
in
range
(
len
(
q_indptr
)
-
1
):
start_idx
=
q_indptr
[
i
]
end_idx
=
q_indptr
[
i
+
1
]
query_ids
[
start_idx
:
end_idx
]
=
i
page_idx
=
torch
.
zeros_like
(
page_idx_local
)
for
i
in
range
(
bsz_tensors
[
0
]):
query_id
=
query_ids
[
i
]
local_block
=
page_idx_local
[
i
]
start_block
=
kv_indptr
[
query_id
]
if
local_block
<
kv_indptr
[
query_id
+
1
]
-
kv_indptr
[
query_id
]:
page_idx
[
i
]
=
kv_indices
[
start_block
+
local_block
]
return
page_idx
,
page_offset
def
get_k_cache
(
self
,
layer_idx
):
return
self
.
k_caches
[
layer_idx
]
def
get_v_cache
(
self
,
layer_idx
):
return
self
.
v_caches
[
layer_idx
]
\ No newline at end of file
ktransformers/models/custom_modeling_qwen2_moe.py
0 → 100644
View file @
3f9bbf11
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""
import
math
from
dataclasses
import
dataclass
import
torch
import
torch.nn
as
nn
from
torch.nn
import
functional
as
F
import
math
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
ktransformers.server.balance_serve.inference.forward_batch
import
ForwardBatchInput
,
ForwardBatchOutput
from
ktransformers.models.custom_cache
import
KGQACache
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeModel
,
Qwen2MoePreTrainedModel
from
ktransformers.models.configuration_qwen2_moe
import
Qwen2MoeConfig
from
ktransformers.operators.flashinfer_batch_prefill_wrapper
import
flashInferAttn
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
import
flashinfer
class
KQwen2MoeForCausalLM
(
Qwen2MoePreTrainedModel
):
cache
:
KGQACache
use_cuda_graph
=
False
def
__init__
(
self
,
config
:
Qwen2MoeConfig
,
cache
,
):
super
().
__init__
(
config
)
self
.
model
=
Qwen2MoeModel
(
config
)
self
.
config
=
config
self
.
cache
=
cache
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
self
.
attn
=
[
None
]
*
10
def
init_wrapper
(
self
,
use_cuda_graph
,
device
,
max_batch_token
,
max_batch_size
,
max_pages
,
cuda_graph_idx
=
0
):
self
.
attn
[
cuda_graph_idx
]
=
flashInferAttn
(
use_cuda_graph
=
use_cuda_graph
,
max_batch_token
=
max_batch_token
,
max_batch_size
=
max_batch_size
,
max_pages
=
max_pages
,
device
=
device
)
def
batch_embeddings
(
self
,
batch
:
ForwardBatchInput
,
device
=
"cuda:0"
):
features
=
[]
for
i
in
range
(
batch
.
batch_size
):
tokens
=
batch
.
minibatch
.
tokens
.
contiguous
()
feature
=
(
self
.
model
.
embed_tokens
(
tokens
.
to
(
torch
.
device
(
'cpu'
)))
.
to
(
torch
.
bfloat16
)
.
to
(
device
=
device
)
)
features
.
append
(
feature
)
return
features
def
forward
(
self
,
batch
:
ForwardBatchInput
|
None
=
None
,
features
:
List
[
torch
.
Tensor
]
|
None
=
None
,
bsz_tensors
:
torch
.
Tensor
|
None
=
None
,
num_tokens_tensors
:
torch
.
Tensor
|
None
=
None
,
page_idx
:
torch
.
Tensor
|
None
=
None
,
page_offset
:
torch
.
Tensor
|
None
=
None
,
cuda_graph_idx
:
int
|
None
=
0
)
->
ForwardBatchOutput
:
current_stream
=
torch
.
cuda
.
current_stream
()
forward_batch_output
=
ForwardBatchOutput
()
hidden_states
=
features
[
0
]
self
.
attn
[
cuda_graph_idx
].
calc_batch_indices
(
hidden_states
.
shape
[
0
])
with
torch
.
cuda
.
stream
(
current_stream
):
residual
=
torch
.
zeros_like
(
hidden_states
)
for
i
,
decode_layer
in
enumerate
(
self
.
model
.
layers
):
if
self
.
model
.
transfer_map
is
not
None
and
i
in
self
.
model
.
transfer_map
:
prev_stream
=
torch
.
cuda
.
current_stream
()
cur_device
=
self
.
model
.
transfer_map
[
i
]
if
cur_device
not
in
self
.
model
.
stream_device_map
:
self
.
model
.
stream_device_map
[
cur_device
]
=
torch
.
cuda
.
Stream
(
cur_device
)
torch
.
cuda
.
set_device
(
cur_device
)
self
.
model
.
stream_device_map
[
cur_device
].
wait_stream
(
prev_stream
)
torch
.
cuda
.
set_stream
(
self
.
model
.
stream_device_map
[
cur_device
])
hidden_states
=
hidden_states
.
to
(
self
.
model
.
transfer_map
[
i
],
non_blocking
=
True
)
batch
.
minibatch
.
position_ids
=
(
batch
.
minibatch
.
position_ids
.
to
(
self
.
model
.
transfer_map
[
i
],
non_blocking
=
True
)
if
batch
.
minibatch
.
position_ids
is
not
None
else
None
)
hidden_states
,
residual
=
decode_layer
.
input_layernorm
(
hidden_states
,
num_tokens_tensors
,
residual
)
hidden_states
=
decode_layer
.
self_attn
(
hidden_states
,
self
.
cache
,
position_ids
=
batch
.
minibatch
.
position_ids
,
wrapper
=
self
.
attn
[
cuda_graph_idx
],
bsz_tensors
=
num_tokens_tensors
,
page_idx
=
page_idx
,
page_offset
=
page_offset
)
hidden_states
,
residual
=
decode_layer
.
post_attention_layernorm
(
hidden_states
,
num_tokens_tensors
,
residual
)
hidden_states
=
decode_layer
.
mlp
(
hidden_states
.
unsqueeze
(
0
),
num_tokens_tensors
,
cuda_graph_idx
)
hidden_states
=
hidden_states
.
squeeze
(
0
)
forward_batch_output
=
ForwardBatchOutput
()
with
torch
.
cuda
.
stream
(
current_stream
):
local_logit
=
self
.
lm_head
(
self
.
model
.
norm
(
hidden_states
,
num_tokens_tensors
,
residual
)[
0
],
num_tokens_tensors
)
forward_batch_output
.
logits
.
append
(
local_logit
)
return
forward_batch_output
def
flash_infer_attn_plan
(
self
,
batch
:
ForwardBatchInput
,
bsz_tensors
,
num_tokens_tensors
,
num_q_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
int
,
page_size
:
int
,
causal
:
bool
,
q_data_type
:
torch
.
dtype
,
kv_data_type
:
torch
.
dtype
,
cuda_graph_idx
:
int
=
0
):
minibatch
=
batch
.
minibatch
self
.
attn
[
cuda_graph_idx
].
plan
(
minibatch
.
q_indptr
,
minibatch
.
kv_indptr
,
minibatch
.
kv_indices
,
minibatch
.
kv_last_page_len
,
bsz_tensors
,
num_tokens_tensors
,
num_q_heads
,
num_kv_heads
,
head_dim
,
page_size
,
causal
=
causal
,
q_data_type
=
q_data_type
,
kv_data_type
=
kv_data_type
)
\ No newline at end of file
ktransformers/models/custom_modeling_qwen3_moe.py
0 → 100644
View file @
3f9bbf11
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""
import
math
from
dataclasses
import
dataclass
import
torch
import
torch.nn
as
nn
from
torch.nn
import
functional
as
F
import
math
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
ktransformers.server.balance_serve.inference.forward_batch
import
ForwardBatchInput
,
ForwardBatchOutput
from
ktransformers.models.custom_cache
import
KGQACache
from
ktransformers.models.modeling_qwen3_moe
import
Qwen3MoeModel
,
Qwen3MoePreTrainedModel
from
ktransformers.models.configuration_qwen3_moe
import
Qwen3MoeConfig
from
ktransformers.operators.flashinfer_batch_prefill_wrapper
import
flashInferAttn
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
import
flashinfer
class
KQwen3MoeForCausalLM
(
Qwen3MoePreTrainedModel
):
cache
:
KGQACache
use_cuda_graph
=
False
def
__init__
(
self
,
config
:
Qwen3MoeConfig
,
cache
=
None
,
):
super
().
__init__
(
config
)
self
.
model
=
Qwen3MoeModel
(
config
)
self
.
config
=
config
self
.
cache
=
cache
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
self
.
attn
=
[
None
]
*
10
def
init_wrapper
(
self
,
use_cuda_graph
,
device
,
max_batch_token
,
max_batch_size
,
max_pages
,
cuda_graph_idx
=
0
):
self
.
attn
[
cuda_graph_idx
]
=
flashInferAttn
(
use_cuda_graph
=
use_cuda_graph
,
max_batch_token
=
max_batch_token
,
max_batch_size
=
max_batch_size
,
max_pages
=
max_pages
,
device
=
device
)
def
batch_embeddings
(
self
,
batch
:
ForwardBatchInput
,
device
=
"cuda:0"
):
features
=
[]
for
i
in
range
(
batch
.
batch_size
):
tokens
=
batch
.
minibatch
.
tokens
.
contiguous
()
feature
=
(
self
.
model
.
embed_tokens
(
tokens
.
to
(
torch
.
device
(
'cpu'
)))
.
to
(
torch
.
bfloat16
)
.
to
(
device
=
device
)
)
features
.
append
(
feature
)
return
features
def
forward
(
self
,
batch
:
ForwardBatchInput
|
None
=
None
,
features
:
List
[
torch
.
Tensor
]
|
None
=
None
,
bsz_tensors
:
torch
.
Tensor
|
None
=
None
,
num_tokens_tensors
:
torch
.
Tensor
|
None
=
None
,
page_idx
:
torch
.
Tensor
|
None
=
None
,
page_offset
:
torch
.
Tensor
|
None
=
None
,
cuda_graph_idx
:
int
|
None
=
0
)
->
ForwardBatchOutput
:
current_stream
=
torch
.
cuda
.
current_stream
()
forward_batch_output
=
ForwardBatchOutput
()
hidden_states
=
features
[
0
]
self
.
attn
[
cuda_graph_idx
].
calc_batch_indices
(
hidden_states
.
shape
[
0
])
with
torch
.
cuda
.
stream
(
current_stream
):
residual
=
torch
.
zeros_like
(
hidden_states
)
for
i
,
decode_layer
in
enumerate
(
self
.
model
.
layers
):
if
self
.
model
.
transfer_map
is
not
None
and
i
in
self
.
model
.
transfer_map
:
prev_stream
=
torch
.
cuda
.
current_stream
()
cur_device
=
self
.
model
.
transfer_map
[
i
]
if
cur_device
not
in
self
.
model
.
stream_device_map
:
self
.
model
.
stream_device_map
[
cur_device
]
=
torch
.
cuda
.
Stream
(
cur_device
)
torch
.
cuda
.
set_device
(
cur_device
)
self
.
model
.
stream_device_map
[
cur_device
].
wait_stream
(
prev_stream
)
torch
.
cuda
.
set_stream
(
self
.
model
.
stream_device_map
[
cur_device
])
hidden_states
=
hidden_states
.
to
(
self
.
model
.
transfer_map
[
i
],
non_blocking
=
True
)
batch
.
minibatch
.
position_ids
=
(
batch
.
minibatch
.
position_ids
.
to
(
self
.
model
.
transfer_map
[
i
],
non_blocking
=
True
)
if
batch
.
minibatch
.
position_ids
is
not
None
else
None
)
hidden_states
,
residual
=
decode_layer
.
input_layernorm
(
hidden_states
,
num_tokens_tensors
,
residual
)
hidden_states
=
decode_layer
.
self_attn
(
hidden_states
,
self
.
cache
,
position_ids
=
batch
.
minibatch
.
position_ids
,
wrapper
=
self
.
attn
[
cuda_graph_idx
],
bsz_tensors
=
num_tokens_tensors
,
page_idx
=
page_idx
,
page_offset
=
page_offset
)
hidden_states
,
residual
=
decode_layer
.
post_attention_layernorm
(
hidden_states
,
num_tokens_tensors
,
residual
)
hidden_states
=
decode_layer
.
mlp
(
hidden_states
.
unsqueeze
(
0
),
num_tokens_tensors
,
cuda_graph_idx
)
hidden_states
=
hidden_states
.
squeeze
(
0
)
forward_batch_output
=
ForwardBatchOutput
()
with
torch
.
cuda
.
stream
(
current_stream
):
local_logit
=
self
.
lm_head
(
self
.
model
.
norm
(
hidden_states
,
num_tokens_tensors
,
residual
)[
0
],
num_tokens_tensors
)
forward_batch_output
.
logits
.
append
(
local_logit
)
return
forward_batch_output
def
flash_infer_attn_plan
(
self
,
batch
:
ForwardBatchInput
,
bsz_tensors
,
num_tokens_tensors
,
num_q_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
int
,
page_size
:
int
,
causal
:
bool
,
q_data_type
:
torch
.
dtype
,
kv_data_type
:
torch
.
dtype
,
cuda_graph_idx
:
int
=
0
):
minibatch
=
batch
.
minibatch
self
.
attn
[
cuda_graph_idx
].
plan
(
minibatch
.
q_indptr
,
minibatch
.
kv_indptr
,
minibatch
.
kv_indices
,
minibatch
.
kv_last_page_len
,
bsz_tensors
,
num_tokens_tensors
,
num_q_heads
,
num_kv_heads
,
head_dim
,
page_size
,
causal
=
causal
,
q_data_type
=
q_data_type
,
kv_data_type
=
kv_data_type
)
\ No newline at end of file
ktransformers/models/modeling_qwen3_moe.py
0 → 100644
View file @
3f9bbf11
This diff is collapsed.
Click to expand it.
ktransformers/operators/RoPE.py
View file @
3f9bbf11
...
@@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):
...
@@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):
self
.
inv_freq
=
1.0
/
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
int64
).
float
().
to
(
device
)
/
self
.
dim
))
self
.
inv_freq
=
1.0
/
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
int64
).
float
().
to
(
device
)
/
self
.
dim
))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
# For BC we register cos and sin cached
self
.
max_seq_len_cached
=
max_position_embeddings
self
.
max_seq_len_cached
=
max_position_embeddings
\ No newline at end of file
class
KQwen3MoeRotaryEmbedding
(
BaseInjectedModule
,
DeepseekV2RotaryEmbedding
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
# device: str = "cuda",
generate_device
:
str
=
"cuda"
,
prefill_device
:
str
=
"cuda"
,
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
config
,
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
def
load
(
self
):
self
.
orig_module
.
__init__
(
self
.
orig_module
.
config
)
\ No newline at end of file
ktransformers/operators/attention.py
View file @
3f9bbf11
...
@@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):
...
@@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights
=
None
attn_weights
=
None
return
attn_output
,
attn_weights
,
past_key_value
return
attn_output
,
attn_weights
,
past_key_value
class
flashinfer_attn
(
BaseInjectedModule
,
DeepseekV2Attention
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
kv_b_proj
=
self
.
kv_b_proj
.
weight
.
view
(
self
.
num_heads
,
-
1
,
self
.
kv_lora_rank
)
q_absorb
=
kv_b_proj
[:,
:
self
.
qk_nope_head_dim
,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
out_absorb
=
kv_b_proj
[:,
self
.
qk_nope_head_dim
:,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
self
.
q_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
qk_nope_head_dim
,
bias
=
False
,
dtype
=
q_absorb
.
dtype
,
device
=
q_absorb
.
device
)
self
.
q_absorb
.
weight
.
data
=
q_absorb
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
self
.
out_absorb
.
weight
.
data
=
out_absorb
#del self.orig_module.kv_b_proj
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
return
q_absorb
,
out_absorb
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KDeepSeekV3Cache
,
position_ids
:
torch
.
Tensor
,
wrapper
:
BatchMLAPagedAttentionWrapper
,
num_tokens_tensors
:
torch
.
Tensor
,
page_idx
:
torch
.
Tensor
,
page_offset
:
torch
.
Tensor
,
):
q_len
,
_
=
hidden_states
.
size
()
if
self
.
q_lora_rank
is
None
:
q
=
self
.
q_proj
(
hidden_states
,
num_tokens_tensors
)
else
:
q
=
self
.
q_b_proj
(
self
.
q_a_layernorm
(
self
.
q_a_proj
(
hidden_states
,
num_tokens_tensors
),
num_tokens_tensors
),
num_tokens_tensors
)
q
=
q
.
view
(
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
q_nope
,
q_pe
=
torch
.
split
(
q
,
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
compressed_kv
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
num_tokens_tensors
)
compressed_kv
,
k_pe
=
torch
.
split
(
compressed_kv
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
compressed_kv
=
compressed_kv
.
contiguous
()
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
,
num_tokens_tensors
)
k_pe
=
k_pe
.
view
(
q_len
,
1
,
self
.
qk_rope_head_dim
)
compressed_kv
=
compressed_kv
.
view
(
q_len
,
1
,
self
.
kv_lora_rank
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
.
unsqueeze
(
0
))
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
.
unsqueeze
(
0
),
k_pe
.
unsqueeze
(
0
),
cos
,
sin
,
unsqueeze_dim
=
2
)
q_pe
=
q_pe
.
squeeze
(
0
)
if
kv_cache
is
not
None
:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"page_idx"
:
page_idx
,
"page_offset"
:
page_offset
}
# Specific to RoPE models
compressed_kv_with_k_pe
=
kv_cache
.
update
(
compressed_kv
.
unsqueeze
(
0
),
k_pe
,
self
.
layer_idx
,
page_idx
,
page_offset
,
cache_kwargs
)
compressed_kv
=
compressed_kv_with_k_pe
[:,
:,
:,
:
self
.
kv_lora_rank
].
view
(
-
1
,
kv_cache
.
page_size
,
self
.
kv_lora_rank
)
k_pe
=
compressed_kv_with_k_pe
[:,
:,
:,
self
.
kv_lora_rank
:].
view
(
-
1
,
kv_cache
.
page_size
,
self
.
qk_rope_head_dim
)
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output
=
wrapper
.
run
(
q_nope
,
q_pe
,
compressed_kv
,
k_pe
).
view
(
q_len
,
self
.
num_heads
,
self
.
kv_lora_rank
)
attn_output
=
attn_output
.
transpose
(
0
,
1
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
# [self.num_heads, q_len, self.v_head_dim]
attn_output
=
attn_output
.
transpose
(
0
,
1
)
attn_output
=
attn_output
.
reshape
(
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
self
.
o_proj
(
attn_output
,
num_tokens_tensors
)
return
attn_output
ktransformers/operators/balance_serve_attention.py
0 → 100644
View file @
3f9bbf11
'''
Description :
Author : Boxin Zhang
Version : 0.2.5
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import
torch
from
torch
import
nn
from
ktransformers.models.modeling_deepseek
import
DeepseekV2Attention
,
apply_rotary_pos_emb
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeAttention
from
ktransformers.models.modeling_qwen3_moe
import
Qwen3MoeAttention
from
typing
import
Optional
,
Tuple
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
import
logging
from
transformers.configuration_utils
import
PretrainedConfig
from
flashinfer
import
BatchMLAPagedAttentionWrapper
from
ktransformers.operators.flashinfer_batch_prefill_wrapper
import
flashInferAttn
from
ktransformers.models.custom_cache
import
KDeepSeekV3Cache
,
KGQACache
logger
=
logging
.
getLogger
(
"attention"
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def
rotate_half
(
x
):
"""Rotates half the hidden dims of the input."""
x1
=
x
[...,
:
x
.
shape
[
-
1
]
//
2
]
x2
=
x
[...,
x
.
shape
[
-
1
]
//
2
:]
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
class
flashinfer_attn
(
BaseInjectedModule
,
DeepseekV2Attention
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
def
get_absorbed
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
(
hasattr
(
self
,
'q_absorb'
)
and
hasattr
(
self
,
'out_absorb'
)):
kv_b_proj
=
self
.
kv_b_proj
.
weight
.
view
(
self
.
num_heads
,
-
1
,
self
.
kv_lora_rank
)
q_absorb
=
kv_b_proj
[:,
:
self
.
qk_nope_head_dim
,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
out_absorb
=
kv_b_proj
[:,
self
.
qk_nope_head_dim
:,
:].
reshape
(
-
1
,
self
.
kv_lora_rank
)
self
.
q_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
qk_nope_head_dim
,
bias
=
False
,
dtype
=
q_absorb
.
dtype
,
device
=
q_absorb
.
device
)
self
.
q_absorb
.
weight
.
data
=
q_absorb
self
.
out_absorb
=
nn
.
Linear
(
self
.
kv_lora_rank
,
self
.
num_heads
*
self
.
v_head_dim
,
bias
=
False
,
dtype
=
out_absorb
.
dtype
,
device
=
out_absorb
.
device
)
self
.
out_absorb
.
weight
.
data
=
out_absorb
#del self.orig_module.kv_b_proj
q_absorb
=
self
.
q_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
qk_nope_head_dim
,
self
.
kv_lora_rank
)
out_absorb
=
self
.
out_absorb
.
weight
.
view
(
self
.
num_heads
,
self
.
v_head_dim
,
self
.
kv_lora_rank
)
return
q_absorb
,
out_absorb
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KDeepSeekV3Cache
,
position_ids
:
torch
.
Tensor
,
wrapper
:
BatchMLAPagedAttentionWrapper
,
num_tokens_tensors
:
torch
.
Tensor
,
page_idx
:
torch
.
Tensor
,
page_offset
:
torch
.
Tensor
,
):
q_len
,
_
=
hidden_states
.
size
()
if
self
.
q_lora_rank
is
None
:
q
=
self
.
q_proj
(
hidden_states
,
num_tokens_tensors
)
else
:
q
=
self
.
q_b_proj
(
self
.
q_a_layernorm
(
self
.
q_a_proj
(
hidden_states
,
num_tokens_tensors
),
num_tokens_tensors
),
num_tokens_tensors
)
q
=
q
.
view
(
q_len
,
self
.
num_heads
,
self
.
q_head_dim
)
q_nope
,
q_pe
=
torch
.
split
(
q
,
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
compressed_kv
=
self
.
kv_a_proj_with_mqa
(
hidden_states
,
num_tokens_tensors
)
compressed_kv
,
k_pe
=
torch
.
split
(
compressed_kv
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
compressed_kv
=
compressed_kv
.
contiguous
()
compressed_kv
=
self
.
kv_a_layernorm
(
compressed_kv
,
num_tokens_tensors
)
k_pe
=
k_pe
.
view
(
q_len
,
1
,
self
.
qk_rope_head_dim
)
compressed_kv
=
compressed_kv
.
view
(
q_len
,
1
,
self
.
kv_lora_rank
)
cos
,
sin
=
self
.
rotary_emb
(
q_pe
,
position_ids
.
unsqueeze
(
0
))
q_pe
,
k_pe
=
apply_rotary_pos_emb
(
q_pe
.
unsqueeze
(
0
),
k_pe
.
unsqueeze
(
0
),
cos
,
sin
,
unsqueeze_dim
=
2
)
q_pe
=
q_pe
.
squeeze
(
0
)
if
kv_cache
is
not
None
:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"page_idx"
:
page_idx
,
"page_offset"
:
page_offset
}
# Specific to RoPE models
compressed_kv_with_k_pe
=
kv_cache
.
update
(
compressed_kv
.
unsqueeze
(
0
),
k_pe
,
self
.
layer_idx
,
page_idx
,
page_offset
,
cache_kwargs
)
compressed_kv
=
compressed_kv_with_k_pe
[:,
:,
:,
:
self
.
kv_lora_rank
].
view
(
-
1
,
kv_cache
.
page_size
,
self
.
kv_lora_rank
)
k_pe
=
compressed_kv_with_k_pe
[:,
:,
:,
self
.
kv_lora_rank
:].
view
(
-
1
,
kv_cache
.
page_size
,
self
.
qk_rope_head_dim
)
q_absorb
,
out_absorb
=
self
.
get_absorbed
()
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# q_len is 1, no GPU overhead, same below
q_nope
=
torch
.
matmul
(
q_nope
,
q_absorb
)
# batched MM
q_nope
=
q_nope
.
transpose
(
0
,
1
)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output
=
wrapper
.
run
(
q_nope
,
q_pe
,
compressed_kv
,
k_pe
).
view
(
q_len
,
self
.
num_heads
,
self
.
kv_lora_rank
)
attn_output
=
attn_output
.
transpose
(
0
,
1
)
attn_output
=
torch
.
matmul
(
attn_output
,
out_absorb
.
mT
)
# [self.num_heads, q_len, self.v_head_dim]
attn_output
=
attn_output
.
transpose
(
0
,
1
)
attn_output
=
attn_output
.
reshape
(
q_len
,
self
.
num_heads
*
self
.
v_head_dim
)
attn_output
=
self
.
o_proj
(
attn_output
,
num_tokens_tensors
)
return
attn_output
class
KQwen2MoeAttention
(
BaseInjectedModule
,
Qwen2MoeAttention
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def
apply_rotary_pos_emb
(
self
,
q
,
k
,
cos
,
sin
,
position_ids
=
None
,
unsqueeze_dim
=
1
):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos
=
cos
.
unsqueeze
(
unsqueeze_dim
)
sin
=
sin
.
unsqueeze
(
unsqueeze_dim
)
q_embed
=
(
q
*
cos
)
+
(
rotate_half
(
q
)
*
sin
)
k_embed
=
(
k
*
cos
)
+
(
rotate_half
(
k
)
*
sin
)
return
q_embed
,
k_embed
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KGQACache
,
position_ids
:
torch
.
Tensor
,
wrapper
:
flashInferAttn
,
bsz_tensors
:
torch
.
Tensor
,
page_idx
:
torch
.
Tensor
,
page_offset
:
torch
.
Tensor
,
):
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_proj
(
hidden_states
,
bsz_tensors
)
key_states
=
self
.
k_proj
(
hidden_states
,
bsz_tensors
)
value_states
=
self
.
v_proj
(
hidden_states
,
bsz_tensors
)
query_states
=
query_states
.
view
(
q_len
,
self
.
num_heads
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
.
unsqueeze
(
0
),
position_ids
.
unsqueeze
(
0
))
query_states
,
key_states
=
self
.
apply_rotary_pos_emb
(
query_states
.
unsqueeze
(
0
),
key_states
.
unsqueeze
(
0
),
cos
,
sin
,
unsqueeze_dim
=
2
)
query_states
=
query_states
.
view
(
q_len
,
self
.
num_heads
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
k_cache
=
kv_cache
.
get_k_cache
(
self
.
layer_idx
)
v_cache
=
kv_cache
.
get_v_cache
(
self
.
layer_idx
)
attn_output
=
wrapper
.
forward
(
query_states
,
k_cache
,
v_cache
,
key_states
,
value_states
)
attn_output
=
self
.
o_proj
(
attn_output
.
view
(
q_len
,
self
.
num_heads
*
self
.
head_dim
),
bsz_tensors
)
return
attn_output
class
KQwen3MoeAttention
(
BaseInjectedModule
,
Qwen3MoeAttention
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
chunck_size
:
int
=
1000
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
layer_idx
)
self
.
chunck_size
=
chunck_size
# TODO, generate chunck_size automatically.
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def
apply_rotary_pos_emb
(
self
,
q
,
k
,
cos
,
sin
,
position_ids
=
None
,
unsqueeze_dim
=
1
):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos
=
cos
.
unsqueeze
(
unsqueeze_dim
)
sin
=
sin
.
unsqueeze
(
unsqueeze_dim
)
q_embed
=
(
q
*
cos
)
+
(
rotate_half
(
q
)
*
sin
)
k_embed
=
(
k
*
cos
)
+
(
rotate_half
(
k
)
*
sin
)
return
q_embed
,
k_embed
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KGQACache
,
position_ids
:
torch
.
Tensor
,
wrapper
:
flashInferAttn
,
bsz_tensors
:
torch
.
Tensor
,
page_idx
:
torch
.
Tensor
,
page_offset
:
torch
.
Tensor
,
):
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_norm
(
self
.
q_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors
)
key_states
=
self
.
k_norm
(
self
.
k_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors
)
value_states
=
self
.
v_proj
(
hidden_states
,
bsz_tensors
)
query_states
=
query_states
.
view
(
q_len
,
self
.
num_heads
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
.
unsqueeze
(
0
),
position_ids
.
unsqueeze
(
0
))
query_states
,
key_states
=
self
.
apply_rotary_pos_emb
(
query_states
.
unsqueeze
(
0
),
key_states
.
unsqueeze
(
0
),
cos
,
sin
,
unsqueeze_dim
=
2
)
query_states
=
query_states
.
view
(
q_len
,
self
.
num_heads
,
self
.
head_dim
)
key_states
=
key_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
value_states
=
value_states
.
view
(
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
k_cache
=
kv_cache
.
get_k_cache
(
self
.
layer_idx
)
v_cache
=
kv_cache
.
get_v_cache
(
self
.
layer_idx
)
attn_output
=
wrapper
.
forward
(
query_states
,
k_cache
,
v_cache
,
key_states
,
value_states
)
attn_output
=
self
.
o_proj
(
attn_output
.
view
(
q_len
,
self
.
num_heads
*
self
.
head_dim
),
bsz_tensors
)
return
attn_output
ktransformers/operators/experts.py
View file @
3f9bbf11
...
@@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
...
@@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from
ktransformers.models.modeling_deepseek
import
DeepseekV2MoE
from
ktransformers.models.modeling_deepseek
import
DeepseekV2MoE
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3MoE
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3MoE
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeSparseMoeBlock
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeSparseMoeBlock
from
ktransformers.models.modeling_qwen3_moe
import
Qwen3MoeSparseMoeBlock
from
ktransformers.models.modeling_mixtral
import
MixtralSparseMoeBlock
from
ktransformers.models.modeling_mixtral
import
MixtralSparseMoeBlock
...
@@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
...
@@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
self
.
unload
()
self
.
unload
()
else
:
else
:
raise
ValueError
(
"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD"
)
raise
ValueError
(
"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD"
)
class
KQwen2MoeSparseMoeBlockV2
(
BaseInjectedModule
,
Qwen2MoeSparseMoeBlock
):
def
forward
(
self
,
hidden_states
,
bsz_tensor
,
cuda_graph_idx
=
0
):
orig_shape
=
hidden_states
.
shape
sequence_length
=
orig_shape
[
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
router_logits
=
self
.
gate
(
hidden_states
,
bsz_tensor
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
routing_weights
,
selected_experts
=
torch
.
topk
(
routing_weights
,
self
.
top_k
,
dim
=-
1
)
if
self
.
norm_topk_prob
:
routing_weights
/=
routing_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
# we cast back to the input dtype
routing_weights
=
routing_weights
.
to
(
hidden_states
.
dtype
)
# only for generate phase
if
hasattr
(
self
.
experts
.
generate_experts
,
"submit_for_one_decode"
)
and
torch
.
cuda
.
is_current_stream_capturing
():
# TODO: this branch cause jit bug
self
.
experts
.
generate_experts
.
submit_for_one_decode
(
hidden_states
,
selected_experts
,
routing_weights
,
bsz_tensor
,
cuda_graph_idx
)
y_
=
self
.
shared_expert
(
hidden_states
,
bsz_tensor
).
squeeze
(
0
)
y_
=
F
.
sigmoid
(
self
.
shared_expert_gate
(
hidden_states
))
*
y_
y
=
self
.
experts
.
generate_experts
.
sync_for_one_decode
(
cuda_graph_idx
).
unsqueeze
(
0
)
y
+=
y_
y
.
resize_
(
*
orig_shape
)
return
y
y_
=
self
.
shared_expert
(
hidden_states
,
bsz_tensor
).
squeeze
(
0
)
y_
=
(
F
.
sigmoid
(
self
.
shared_expert_gate
(
hidden_states
))
*
y_
)
if
isinstance
(
self
.
experts
,
KExpertsBase
):
y
=
self
.
moe_on_cpuinfer
(
hidden_states
,
selected_experts
,
routing_weights
,
bsz_tensor
,
cuda_graph_idx
).
view
(
*
orig_shape
).
to
(
device
=
hidden_states
.
device
)
elif
hidden_states
.
size
(
0
)
>
10
:
# TODO may bugs here
y
=
(
self
.
moe_infer
(
hidden_states
,
selected_experts
,
routing_weights
)
.
view
(
*
orig_shape
)
.
to
(
device
=
hidden_states
.
device
)
)
else
:
# TODO may bugs here
y
=
(
self
.
moe_infer_simple
(
hidden_states
,
selected_experts
,
routing_weights
)
.
view
(
*
orig_shape
)
.
to
(
device
=
hidden_states
.
device
)
)
y
+=
y_
return
y
@
torch
.
no_grad
()
def
moe_on_cpuinfer
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
bsz_tensor
,
cuda_graph_idx
=
0
)
->
torch
.
Tensor
:
outs
=
torch
.
empty_like
(
x
)
outs
=
self
.
experts
(
x
,
topk_ids
,
topk_weight
,
bsz_tensor
,
cuda_graph_idx
)
return
outs
@
torch
.
no_grad
()
# TODO may bugs here
def
moe_infer_simple
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs
=
torch
.
zeros_like
(
x
)
for
token_idx
in
range
(
topk_ids
.
size
(
0
)):
for
expert_idx
in
range
(
topk_ids
.
size
(
1
)):
expert
=
self
.
experts
[
topk_ids
[
token_idx
,
expert_idx
]]
outs
[
token_idx
]
+=
(
expert
.
forward
(
x
[
token_idx
])
*
topk_weight
[
token_idx
,
expert_idx
]
)
return
outs
@
torch
.
no_grad
()
# TODO may bugs here
def
moe_infer
(
self
,
x
,
topk_ids
,
topk_weight
):
cnts
=
topk_ids
.
new_zeros
((
topk_ids
.
shape
[
0
],
len
(
self
.
experts
)))
cnts
.
scatter_
(
1
,
topk_ids
,
1
)
tokens_per_expert
=
cnts
.
sum
(
dim
=
0
)
idxs
=
topk_ids
.
view
(
-
1
).
argsort
()
sorted_tokens
=
x
[
idxs
//
topk_ids
.
shape
[
1
]]
tokens_per_expert
=
tokens_per_expert
.
cpu
().
numpy
()
outputs
=
[]
start_idx
=
0
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
end_idx
=
start_idx
+
num_tokens
if
num_tokens
==
0
:
continue
expert
=
self
.
experts
[
i
+
self
.
ep_rank
*
self
.
experts_per_rank
]
tokens_for_this_expert
=
sorted_tokens
[
start_idx
:
end_idx
]
expert_out
=
expert
.
forward
(
tokens_for_this_expert
)
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
outs
=
torch
.
cat
(
outputs
,
dim
=
0
)
if
len
(
outputs
)
else
sorted_tokens
.
new_empty
(
0
)
new_x
=
torch
.
empty_like
(
outs
)
new_x
[
idxs
]
=
outs
final_out
=
(
new_x
.
view
(
*
topk_ids
.
shape
,
-
1
)
.
type
(
topk_weight
.
dtype
)
.
mul_
(
topk_weight
.
unsqueeze
(
dim
=-
1
))
.
sum
(
dim
=
1
)
.
type
(
new_x
.
dtype
)
)
return
final_out
class
KQwen3MoeSparseMoeBlockV2
(
BaseInjectedModule
,
Qwen3MoeSparseMoeBlock
):
def
forward
(
self
,
hidden_states
,
bsz_tensor
,
cuda_graph_idx
=
0
):
orig_shape
=
hidden_states
.
shape
sequence_length
=
orig_shape
[
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
router_logits
=
self
.
gate
(
hidden_states
,
bsz_tensor
)
routing_weights
=
F
.
softmax
(
router_logits
,
dim
=
1
,
dtype
=
torch
.
float
)
routing_weights
,
selected_experts
=
torch
.
topk
(
routing_weights
,
self
.
top_k
,
dim
=-
1
)
if
self
.
norm_topk_prob
:
routing_weights
/=
routing_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
# we cast back to the input dtype
routing_weights
=
routing_weights
.
to
(
hidden_states
.
dtype
)
# only for generate phase
if
hasattr
(
self
.
experts
.
generate_experts
,
"submit_for_one_decode"
)
and
torch
.
cuda
.
is_current_stream_capturing
():
# TODO: this branch cause jit bug
self
.
experts
.
generate_experts
.
submit_for_one_decode
(
hidden_states
,
selected_experts
,
routing_weights
,
bsz_tensor
,
cuda_graph_idx
)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y
=
self
.
experts
.
generate_experts
.
sync_for_one_decode
(
cuda_graph_idx
).
unsqueeze
(
0
)
# y += y_
y
.
resize_
(
*
orig_shape
)
return
y
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = (
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
# )
if
isinstance
(
self
.
experts
,
KExpertsBase
):
y
=
self
.
moe_on_cpuinfer
(
hidden_states
,
selected_experts
,
routing_weights
,
bsz_tensor
,
cuda_graph_idx
).
view
(
*
orig_shape
).
to
(
device
=
hidden_states
.
device
)
elif
hidden_states
.
size
(
0
)
>
10
:
# TODO may bugs here
y
=
(
self
.
moe_infer
(
hidden_states
,
selected_experts
,
routing_weights
)
.
view
(
*
orig_shape
)
.
to
(
device
=
hidden_states
.
device
)
)
else
:
# TODO may bugs here
y
=
(
self
.
moe_infer_simple
(
hidden_states
,
selected_experts
,
routing_weights
)
.
view
(
*
orig_shape
)
.
to
(
device
=
hidden_states
.
device
)
)
# y += y_
return
y
@
torch
.
no_grad
()
def
moe_on_cpuinfer
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
bsz_tensor
,
cuda_graph_idx
=
0
)
->
torch
.
Tensor
:
outs
=
torch
.
empty_like
(
x
)
outs
=
self
.
experts
(
x
,
topk_ids
,
topk_weight
,
bsz_tensor
,
cuda_graph_idx
)
return
outs
@
torch
.
no_grad
()
# TODO may bugs here
def
moe_infer_simple
(
self
,
x
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs
=
torch
.
zeros_like
(
x
)
for
token_idx
in
range
(
topk_ids
.
size
(
0
)):
for
expert_idx
in
range
(
topk_ids
.
size
(
1
)):
expert
=
self
.
experts
[
topk_ids
[
token_idx
,
expert_idx
]]
outs
[
token_idx
]
+=
(
expert
.
forward
(
x
[
token_idx
])
*
topk_weight
[
token_idx
,
expert_idx
]
)
return
outs
@
torch
.
no_grad
()
# TODO may bugs here
def
moe_infer
(
self
,
x
,
topk_ids
,
topk_weight
):
cnts
=
topk_ids
.
new_zeros
((
topk_ids
.
shape
[
0
],
len
(
self
.
experts
)))
cnts
.
scatter_
(
1
,
topk_ids
,
1
)
tokens_per_expert
=
cnts
.
sum
(
dim
=
0
)
idxs
=
topk_ids
.
view
(
-
1
).
argsort
()
sorted_tokens
=
x
[
idxs
//
topk_ids
.
shape
[
1
]]
tokens_per_expert
=
tokens_per_expert
.
cpu
().
numpy
()
outputs
=
[]
start_idx
=
0
for
i
,
num_tokens
in
enumerate
(
tokens_per_expert
):
end_idx
=
start_idx
+
num_tokens
if
num_tokens
==
0
:
continue
expert
=
self
.
experts
[
i
+
self
.
ep_rank
*
self
.
experts_per_rank
]
tokens_for_this_expert
=
sorted_tokens
[
start_idx
:
end_idx
]
expert_out
=
expert
.
forward
(
tokens_for_this_expert
)
outputs
.
append
(
expert_out
)
start_idx
=
end_idx
outs
=
torch
.
cat
(
outputs
,
dim
=
0
)
if
len
(
outputs
)
else
sorted_tokens
.
new_empty
(
0
)
new_x
=
torch
.
empty_like
(
outs
)
new_x
[
idxs
]
=
outs
final_out
=
(
new_x
.
view
(
*
topk_ids
.
shape
,
-
1
)
.
type
(
topk_weight
.
dtype
)
.
mul_
(
topk_weight
.
unsqueeze
(
dim
=-
1
))
.
sum
(
dim
=
1
)
.
type
(
new_x
.
dtype
)
)
return
final_out
\ No newline at end of file
ktransformers/operators/flashinfer_batch_prefill_wrapper.py
0 → 100644
View file @
3f9bbf11
import
torch
import
flashinfer
import
gc
try
:
from
flash_attn
import
flash_attn_with_kvcache
print
(
"found flash_attn"
)
except
ImportError
:
print
(
"flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this."
)
from
typing
import
Union
,
Optional
def
setup_seed
(
seed
):
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
setup_seed
(
998244353
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
global_dtype
=
torch
.
bfloat16
global_device
=
torch
.
device
(
"cuda"
,
0
)
torch
.
cuda
.
set_device
(
0
)
torch
.
backends
.
cudnn
.
enabled
=
True
torch
.
backends
.
cudnn
.
benchmark
=
True
class
flashInferAttn
():
float_workspace_buffer
=
None
def
__init__
(
self
,
max_batch_token
,
max_batch_size
,
max_pages
,
device
=
"cuda:0"
,
kv_layout
:
str
=
"NHD"
,
use_cuda_graph
:
bool
=
False
,
)
->
None
:
self
.
device
=
device
self
.
max_batch_token
=
max_batch_token
self
.
kv_layout
=
kv_layout
self
.
use_cuda_graph
=
use_cuda_graph
if
flashInferAttn
.
float_workspace_buffer
is
None
:
flashInferAttn
.
float_workspace_buffer
=
torch
.
empty
(
1024
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
device
)
self
.
qo_indptr_buf
=
torch
.
empty
((
max_batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_indptr_buf
=
torch
.
empty
((
max_batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_indices_buf
=
torch
.
empty
((
max_pages
,),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
paged_kv_last_page_len_buf
=
torch
.
empty
((
max_batch_size
,),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
batch_size_tensor_buf
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
device
)
self
.
num_tokens_tensor_buf
=
torch
.
empty
((
1
,),
dtype
=
torch
.
uint32
,
device
=
device
)
# TODO: custom mask
self
.
custom_mask_buf
=
None
self
.
qk_indptr_buf
=
None
self
.
warpper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
flashInferAttn
.
float_workspace_buffer
,
self
.
kv_layout
,
use_cuda_graph
=
self
.
use_cuda_graph
,
qo_indptr_buf
=
self
.
qo_indptr_buf
,
paged_kv_indptr_buf
=
self
.
paged_kv_indptr_buf
,
paged_kv_indices_buf
=
self
.
paged_kv_indices_buf
,
paged_kv_last_page_len_buf
=
self
.
paged_kv_last_page_len_buf
,
backend
=
"fa2"
,
)
def
plan
(
self
,
qo_indptr
:
torch
.
Tensor
,
paged_kv_indptr
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_last_page_len
:
torch
.
Tensor
,
batch_size_tensor
:
torch
.
Tensor
,
num_tokens_tensor
:
torch
.
Tensor
,
num_qo_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
int
,
page_size
:
int
,
causal
:
bool
=
True
,
pos_encoding_mode
:
str
=
"NONE"
,
q_data_type
:
Union
[
str
,
torch
.
dtype
]
=
torch
.
bfloat16
,
kv_data_type
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
):
self
.
batch_size_tensor_buf
.
copy_
(
batch_size_tensor
,
non_blocking
=
True
)
self
.
num_tokens_tensor_buf
.
copy_
(
num_tokens_tensor
,
non_blocking
=
True
)
self
.
page_size
=
page_size
self
.
warpper
.
plan
(
qo_indptr
,
paged_kv_indptr
,
paged_kv_indices
,
paged_kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
page_size
,
causal
=
causal
,
pos_encoding_mode
=
pos_encoding_mode
,
q_data_type
=
q_data_type
,
kv_data_type
=
kv_data_type
)
def
calc_batch_indices
(
self
,
ragged_size
=
None
):
if
self
.
use_cuda_graph
:
self
.
batch_indices
,
self
.
positions
=
flashinfer
.
get_batch_indices_positions
(
self
.
qo_indptr_buf
,
flashinfer
.
get_seq_lens
(
self
.
paged_kv_indptr_buf
,
self
.
paged_kv_last_page_len_buf
,
self
.
page_size
),
self
.
batch_size_tensor_buf
,
self
.
max_batch_token
)
else
:
self
.
batch_indices
,
self
.
positions
=
flashinfer
.
get_batch_indices_positions
(
self
.
warpper
.
_qo_indptr_buf
,
flashinfer
.
get_seq_lens
(
self
.
warpper
.
_paged_kv_indptr_buf
,
self
.
warpper
.
_paged_kv_last_page_len_buf
,
self
.
page_size
),
self
.
batch_size_tensor_buf
,
ragged_size
)
def
forward
(
self
,
q
,
k_cache
,
v_cache
,
k
,
v
):
if
self
.
use_cuda_graph
:
flashinfer
.
page
.
append_paged_kv_cache
(
k
,
v
,
self
.
batch_indices
,
self
.
positions
,
(
k_cache
,
v_cache
),
self
.
paged_kv_indices_buf
,
self
.
paged_kv_indptr_buf
,
self
.
paged_kv_last_page_len_buf
,
self
.
num_tokens_tensor_buf
)
return
self
.
warpper
.
run
(
q
,
(
k_cache
,
v_cache
))
else
:
flashinfer
.
page
.
append_paged_kv_cache
(
k
,
v
,
self
.
batch_indices
,
self
.
positions
,
(
k_cache
,
v_cache
),
self
.
warpper
.
_paged_kv_indices_buf
,
self
.
warpper
.
_paged_kv_indptr_buf
,
self
.
warpper
.
_paged_kv_last_page_len_buf
,
self
.
num_tokens_tensor_buf
)
return
self
.
warpper
.
run
(
q
,
(
k_cache
,
v_cache
))
def
testCudaGraph
():
# use max batch to create buffer
batch_decode
=
8
prefill_chunk
=
48
past_kv_0
=
4090
past_kv_1
=
4096
raged_size
=
prefill_chunk
+
batch_decode
num_key_value_heads
=
8
head_dim
=
128
num_attention_heads
=
64
page_size
=
256
num_pages_per_seq
=
(
past_kv_1
+
page_size
-
1
)
//
page_size
total_num_pages
=
(
num_pages_per_seq
+
1
)
*
(
batch_decode
+
1
)
+
prefill_chunk
//
page_size
attn
=
flashInferAttn
(
raged_size
,
batch_decode
+
1
,
total_num_pages
,
use_cuda_graph
=
True
)
batch_size_tensor
=
torch
.
tensor
([
batch_decode
+
1
],
device
=
global_device
,
dtype
=
torch
.
int32
)
k_caches
=
[]
v_caches
=
[]
ks
=
[]
vs
=
[]
qs
=
[]
for
layer_idx
in
range
(
3
):
k_caches
.
append
(
torch
.
randn
(
total_num_pages
,
page_size
,
num_key_value_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
v_caches
.
append
(
torch
.
randn
(
total_num_pages
,
page_size
,
num_key_value_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
ks
.
append
(
torch
.
randn
(
raged_size
,
num_key_value_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
vs
.
append
(
torch
.
randn
(
raged_size
,
num_key_value_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
qs
.
append
(
torch
.
randn
(
raged_size
,
num_attention_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
# warmup and capture small batch
past_kv_0
=
250
past_kv_1
=
256
num_pages_per_seq
=
(
past_kv_1
+
page_size
-
1
)
//
page_size
total_num_pages
=
(
num_pages_per_seq
+
1
)
*
(
batch_decode
+
1
)
+
prefill_chunk
//
page_size
q_indptr
=
torch
.
empty
((
batch_decode
+
2
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
q_indptr
[
0
]
=
0
q_indptr
[
1
:]
=
torch
.
arange
(
prefill_chunk
,
prefill_chunk
+
batch_decode
+
1
,
device
=
global_device
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
arange
(
0
,
batch_decode
+
2
,
device
=
global_device
,
dtype
=
torch
.
int32
)
*
num_pages_per_seq
kv_indices
=
torch
.
arange
(
0
,
total_num_pages
,
device
=
global_device
,
dtype
=
torch
.
int32
)
kv_last_page_len
=
torch
.
empty
((
batch_decode
+
1
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
kv_last_page_len
[:
1
+
batch_decode
//
2
]
=
int
((
past_kv_0
-
1
)
%
page_size
+
1
)
kv_last_page_len
[
1
+
batch_decode
//
2
:]
=
int
((
past_kv_1
-
1
)
%
page_size
+
1
)
print
(
q_indptr
)
print
(
kv_indptr
)
print
(
kv_indices
)
print
(
kv_last_page_len
)
attn
.
plan
(
q_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_len
,
batch_size_tensor
,
num_attention_heads
,
num_key_value_heads
,
head_dim
,
page_size
,
causal
=
True
,
pos_encoding_mode
=
"NONE"
,
q_data_type
=
torch
.
bfloat16
)
attn
.
calc_batch_indices
(
raged_size
)
for
layer_idx
in
range
(
3
):
attn
.
forward
(
qs
[
layer_idx
],
k_caches
[
layer_idx
],
v_caches
[
layer_idx
],
ks
[
layer_idx
],
vs
[
layer_idx
])
torch
.
cuda
.
synchronize
()
outs
=
[]
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
for
layer_idx
in
range
(
3
):
outs
.
append
(
attn
.
forward
(
qs
[
layer_idx
],
k_caches
[
layer_idx
],
v_caches
[
layer_idx
],
ks
[
layer_idx
],
vs
[
layer_idx
]))
g
.
replay
()
kv_last_page_len
[:
1
+
batch_decode
//
2
]
=
int
(
past_kv_0
)
kv_last_page_len
[
1
+
batch_decode
//
2
:]
=
int
(
past_kv_1
)
for
layer_idx
in
range
(
3
):
for
i
in
range
(
batch_decode
+
1
):
qi
=
qs
[
layer_idx
][
q_indptr
[
i
]
:
q_indptr
[
i
+
1
]]
o_ref_i
=
flash_attn_with_kvcache
(
qi
.
unsqueeze
(
0
),
k_caches
[
layer_idx
],
v_caches
[
layer_idx
],
causal
=
True
,
block_table
=
kv_indices
[
kv_indptr
[
i
]:
kv_indptr
[
i
+
1
]].
unsqueeze
(
0
),
cache_seqlens
=
torch
.
tensor
([
past_kv_0
if
i
<
1
+
batch_decode
//
2
else
past_kv_1
],
device
=
global_device
,
dtype
=
torch
.
int32
)
)
o_i
=
outs
[
layer_idx
][
q_indptr
[
i
]
:
q_indptr
[
i
+
1
]]
print
(
layer_idx
,
i
)
torch
.
testing
.
assert_close
(
o_i
.
unsqueeze
(
0
),
o_ref_i
,
rtol
=
5e-3
,
atol
=
5e-3
)
# run another batch size use capture cuda graph
past_kv_0
=
4090
past_kv_1
=
4096
prefill_chunk
=
24
batch_decode
=
4
num_pages_per_seq
=
(
past_kv_1
+
page_size
-
1
)
//
page_size
total_num_pages
=
(
num_pages_per_seq
+
1
)
*
(
batch_decode
+
1
)
+
prefill_chunk
//
page_size
batch_size_tensor
=
torch
.
tensor
([
batch_decode
+
1
],
device
=
global_device
,
dtype
=
torch
.
int32
)
num_tokens_tensor
=
torch
.
tensor
([
batch_decode
+
prefill_chunk
],
device
=
global_device
,
dtype
=
torch
.
int32
)
q_indptr
=
torch
.
empty
((
batch_decode
+
2
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
q_indptr
[
0
]
=
0
q_indptr
[
1
:]
=
torch
.
arange
(
prefill_chunk
,
prefill_chunk
+
batch_decode
+
1
,
device
=
global_device
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
arange
(
0
,
batch_decode
+
2
,
device
=
global_device
,
dtype
=
torch
.
int32
)
*
num_pages_per_seq
kv_indices
=
torch
.
arange
(
0
,
total_num_pages
,
device
=
global_device
,
dtype
=
torch
.
int32
)
kv_last_page_len
=
torch
.
empty
((
batch_decode
+
1
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
kv_last_page_len
[:
1
+
batch_decode
//
2
]
=
int
((
past_kv_0
-
1
)
%
page_size
+
1
)
kv_last_page_len
[
1
+
batch_decode
//
2
:]
=
int
((
past_kv_1
-
1
)
%
page_size
+
1
)
attn
.
plan
(
q_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_len
,
batch_size_tensor
,
num_attention_heads
,
num_key_value_heads
,
head_dim
,
page_size
,
causal
=
True
,
pos_encoding_mode
=
"NONE"
,
q_data_type
=
torch
.
bfloat16
)
attn
.
calc_batch_indices
(
raged_size
)
g
.
replay
()
kv_last_page_len
[:
1
+
batch_decode
//
2
]
=
int
(
past_kv_0
)
kv_last_page_len
[
1
+
batch_decode
//
2
:]
=
int
(
past_kv_1
)
for
layer_idx
in
range
(
3
):
for
i
in
range
(
batch_decode
+
1
):
qi
=
qs
[
layer_idx
][
q_indptr
[
i
]
:
q_indptr
[
i
+
1
]]
o_ref_i
=
flash_attn_with_kvcache
(
qi
.
unsqueeze
(
0
),
k_caches
[
layer_idx
],
v_caches
[
layer_idx
],
causal
=
True
,
block_table
=
kv_indices
[
kv_indptr
[
i
]:
kv_indptr
[
i
+
1
]].
unsqueeze
(
0
),
cache_seqlens
=
torch
.
tensor
([
past_kv_0
if
i
<
1
+
batch_decode
//
2
else
past_kv_1
],
device
=
global_device
,
dtype
=
torch
.
int32
)
)
o_i
=
outs
[
layer_idx
][
q_indptr
[
i
]
:
q_indptr
[
i
+
1
]]
print
(
layer_idx
,
i
)
torch
.
testing
.
assert_close
(
o_i
.
unsqueeze
(
0
),
o_ref_i
,
rtol
=
5e-3
,
atol
=
5e-3
)
def
testAttentionFlashInfer
(
):
batch_decode
=
32
prefill_chunk
=
64
past_kv_0
=
510
past_kv_1
=
512
raged_size
=
prefill_chunk
+
batch_decode
num_key_value_heads
=
8
head_dim
=
128
num_attention_heads
=
64
cases
=
1
page_size
=
32
num_pages_per_seq
=
(
past_kv_1
+
page_size
-
1
)
//
page_size
total_num_pages
=
(
num_pages_per_seq
+
1
)
*
(
batch_decode
+
1
)
+
prefill_chunk
//
page_size
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
"cuda:0"
)
qs
=
[]
kvs
=
[]
q_indptrs
=
[]
kv_indptrs
=
[]
kv_indicess
=
[]
kv_last_page_lens
=
[]
wrappers
=
[]
for
case_id
in
range
(
cases
):
kvs
.
append
(
torch
.
randn
(
total_num_pages
,
2
,
page_size
,
num_key_value_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
qs
.
append
(
torch
.
randn
(
raged_size
,
num_attention_heads
,
head_dim
,
device
=
global_device
,
dtype
=
torch
.
bfloat16
))
q_indptr
=
torch
.
empty
((
batch_decode
+
2
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
q_indptr
[
0
]
=
0
q_indptr
[
1
:]
=
torch
.
arange
(
prefill_chunk
,
prefill_chunk
+
batch_decode
+
1
,
device
=
global_device
,
dtype
=
torch
.
int32
)
q_indptrs
.
append
(
q_indptr
)
kv_indptrs
.
append
(
torch
.
arange
(
0
,
batch_decode
+
2
,
device
=
global_device
,
dtype
=
torch
.
int32
)
*
num_pages_per_seq
)
kv_indicess
.
append
(
torch
.
arange
(
0
,
total_num_pages
,
device
=
global_device
,
dtype
=
torch
.
int32
))
kv_last_page_len
=
torch
.
empty
((
batch_decode
+
1
,),
dtype
=
torch
.
int32
,
device
=
global_device
)
kv_last_page_len
[:
1
+
batch_decode
//
2
]
=
int
((
past_kv_0
-
1
)
%
page_size
+
1
)
kv_last_page_len
[
1
+
batch_decode
//
2
:]
=
int
((
past_kv_1
-
1
)
%
page_size
+
1
)
kv_last_page_lens
.
append
(
kv_last_page_len
)
wrappers
.
append
(
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_cuda_graph
=
True
,
qo_indptr_buf
=
q_indptrs
[
case_id
],
paged_kv_indptr_buf
=
kv_indptrs
[
case_id
],
paged_kv_indices_buf
=
kv_indicess
[
case_id
],
paged_kv_last_page_len_buf
=
kv_last_page_lens
[
case_id
],
))
wrappers
[
case_id
].
plan
(
q_indptrs
[
case_id
],
kv_indptrs
[
case_id
],
kv_indicess
[
case_id
],
kv_last_page_lens
[
case_id
],
num_attention_heads
,
num_key_value_heads
,
head_dim
,
page_size
,
causal
=
True
,
pos_encoding_mode
=
"ROPE_LLAMA"
,
q_data_type
=
torch
.
bfloat16
)
def
custom_forward
(
case_id
):
out
=
wrappers
[
case_id
].
run
(
qs
[
case_id
],
kvs
[
case_id
])
custom_forward
(
0
)
# testCudaGraph()
# pass
\ No newline at end of file
ktransformers/operators/gate.py
View file @
3f9bbf11
...
@@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
...
@@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self
.
e_score_correction_bias
=
None
self
.
e_score_correction_bias
=
None
class
KMoEGateQwen2Moe
(
BaseInjectedModule
,
KMoEGateBase
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
=
None
,
generate_device
:
str
=
"cuda"
,
generate_op
:
str
|
None
=
"KLinearMarlin"
,
prefill_device
:
str
=
"cuda"
,
prefill_op
:
str
|
None
=
"KLinearMarlin"
,
use_quant
:
bool
=
False
,
**
kwargs
,
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
generate_device
,
**
kwargs
)
KMoEGateBase
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
generate_device
,
**
kwargs
)
self
.
generate_device
=
generate_device
self
.
prefill_device
=
prefill_device
self
.
generate_op
=
generate_op
self
.
prefill_op
=
prefill_op
self
.
is_windows
=
os
.
name
==
'nt'
self
.
use_quant
=
use_quant
if
not
self
.
is_windows
and
use_quant
:
self
.
gate_linear
=
nn
.
Linear
(
self
.
gating_dim
,
self
.
n_routed_experts
,
device
=
generate_device
)
self
.
gate_linear
=
KTransformersLinear
(
key
+
".ffn_gate_inp"
,
gguf_loader
,
config
,
self
.
gate_linear
,
#orig_module
generate_device
,
generate_op
,
prefill_device
,
prefill_op
)
else
:
self
.
gate_linear
=
None
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
if
self
.
is_windows
:
return
self
.
orig_module
.
forward
(
hidden_states
)
bsz
,
seq_len
,
h
=
hidden_states
.
shape
### compute gating score
hidden_states
=
hidden_states
.
view
(
-
1
,
h
)
if
self
.
use_quant
:
logits
=
self
.
gate_linear
.
forward
(
logits
)
else
:
logits
=
F
.
linear
(
hidden_states
.
type
(
torch
.
float32
),
self
.
weight
.
type
(
torch
.
float32
),
None
)
return
grouped_topk
(
hidden_states
,
logits
,
self
.
top_k
,
self
.
norm_topk_prob
,
self
.
n_group
,
self
.
topk_group
)
def
load
(
self
,
w
:
dict
|
nn
.
Parameter
|
tuple
|
None
=
None
,
device
:
str
|
None
=
None
):
if
device
is
None
:
device
=
self
.
device
if
w
is
None
:
w
=
self
.
load_weights
(
device
=
device
)
if
isinstance
(
w
,
dict
):
self
.
weight_type
=
w
[
"weight_type"
]
self
.
e_score_correction_bias_type
=
w
[
"e_score_correction_bias_type"
]
self
.
orig_module
.
weight
=
nn
.
Parameter
(
w
[
"weight"
])
self
.
orig_module
.
e_score_correction_bias
=
nn
.
Parameter
(
w
[
"e_score_correction_bias"
])
else
:
raise
ValueError
(
"Invalid weight type"
)
self
.
orig_module
.
weight
=
nn
.
Parameter
(
self
.
orig_module
.
weight
.
to
(
device
))
self
.
orig_module
.
e_score_correction_bias
=
nn
.
Parameter
(
self
.
orig_module
.
e_score_correction_bias
.
to
(
device
))
if
not
self
.
is_windows
and
self
.
use_quant
:
self
.
gate_linear
.
load
(
self
.
orig_module
.
weight
)
def
unload
(
self
):
if
self
.
weight
is
not
None
:
self
.
weight
=
None
if
self
.
e_score_correction_bias
is
not
None
:
self
.
e_score_correction_bias
=
None
\ No newline at end of file
ktransformers/operators/layernorm.py
View file @
3f9bbf11
...
@@ -26,6 +26,8 @@ from transformers import PretrainedConfig
...
@@ -26,6 +26,8 @@ from transformers import PretrainedConfig
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3RMSNorm
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3RMSNorm
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeRMSNorm
from
ktransformers.models.modeling_qwen3_moe
import
Qwen3MoeRMSNorm
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.operators.base_operator
import
BaseInjectedModule
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
ktransformers.util.custom_gguf
import
GGUFLoader
from
flashinfer.norm
import
(
from
flashinfer.norm
import
(
...
@@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
...
@@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
\ No newline at end of file
class
KQwen2MoeRMSNorm
(
Qwen2MoeRMSNorm
,
BaseInjectedModule
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
config
.
hidden_size
,
orig_module
.
variance_epsilon
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
batch_size_tensor
:
torch
.
Tensor
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
#return self.forward_native(x, residual)
if
batch_size_tensor
is
None
:
return
self
.
forward_native
(
x
)
if
residual
is
not
None
:
fused_add_rmsnorm
(
x
,
residual
,
self
.
weight
.
data
,
batch_size_tensor
,
self
.
variance_epsilon
)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return
x
,
residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out
=
rmsnorm
(
x
,
self
.
weight
.
data
,
batch_size_tensor
,
self
.
variance_epsilon
)
return
out
def
forward_native
(
self
,
hidden_states
):
input_dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
class
KQwen3MoeRMSNorm
(
Qwen3MoeRMSNorm
,
BaseInjectedModule
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
hidden_size
,
orig_module
.
variance_epsilon
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
batch_size_tensor
:
torch
.
Tensor
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
#return self.forward_native(x, residual)
bsz
,
hidden_size
=
x
.
shape
x
=
x
.
view
(
-
1
,
self
.
orig_module
.
hidden_size
)
if
batch_size_tensor
is
None
:
return
self
.
forward_native
(
x
)
if
residual
is
not
None
:
fused_add_rmsnorm
(
x
,
residual
,
self
.
weight
.
data
,
batch_size_tensor
,
self
.
variance_epsilon
)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return
x
,
residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out
=
rmsnorm
(
x
,
self
.
weight
.
data
,
batch_size_tensor
,
self
.
variance_epsilon
)
out
=
out
.
view
(
bsz
,
hidden_size
)
return
out
def
forward_native
(
self
,
hidden_states
):
input_dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
return
self
.
weight
*
hidden_states
.
to
(
input_dtype
)
ktransformers/operators/mlp.py
View file @
3f9bbf11
...
@@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
...
@@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
torch.nn
as
nn
import
torch.nn
as
nn
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3MLP
from
ktransformers.models.modeling_deepseek_v3
import
DeepseekV3MLP
from
ktransformers.models.modeling_qwen2_moe
import
Qwen2MoeMLP
class
kDeepseekV3MLP
(
DeepseekV3MLP
,
BaseInjectedModule
):
class
kDeepseekV3MLP
(
DeepseekV3MLP
,
BaseInjectedModule
):
def
__init__
(
self
,
def
__init__
(
self
,
key
:
str
,
key
:
str
,
...
@@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
...
@@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
hidden_size
,
orig_module
.
intermediate_size
)
orig_module
.
hidden_size
,
orig_module
.
intermediate_size
)
def
forward
(
self
,
x
,
bsz_tensor
):
down_proj
=
self
.
down_proj
(
self
.
act_fn
(
self
.
gate_proj
(
x
,
bsz_tensor
))
*
self
.
up_proj
(
x
,
bsz_tensor
),
bsz_tensor
)
return
down_proj
class
KQwen2MoeMLP
(
Qwen2MoeMLP
,
BaseInjectedModule
):
def
__init__
(
self
,
key
:
str
,
gguf_loader
:
GGUFLoader
,
config
:
PretrainedConfig
,
orig_module
:
nn
.
Module
,
prefill_device
:
str
=
"cuda"
,
generate_device
:
str
=
"cuda"
,
**
kwargs
):
BaseInjectedModule
.
__init__
(
self
,
key
,
gguf_loader
,
config
,
orig_module
,
prefill_device
,
**
kwargs
)
self
.
orig_module
.
__init__
(
orig_module
.
config
,
orig_module
.
intermediate_size
)
def
forward
(
self
,
x
,
bsz_tensor
):
def
forward
(
self
,
x
,
bsz_tensor
):
down_proj
=
self
.
down_proj
(
self
.
act_fn
(
self
.
gate_proj
(
x
,
bsz_tensor
))
*
self
.
up_proj
(
x
,
bsz_tensor
),
bsz_tensor
)
down_proj
=
self
.
down_proj
(
self
.
act_fn
(
self
.
gate_proj
(
x
,
bsz_tensor
))
*
self
.
up_proj
(
x
,
bsz_tensor
),
bsz_tensor
)
return
down_proj
return
down_proj
\ No newline at end of file
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml
View file @
3f9bbf11
...
@@ -56,7 +56,7 @@
...
@@ -56,7 +56,7 @@
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
replace
:
replace
:
class
:
ktransformers.operators.attention.flashinfer_attn
# optimized MLA implementation
class
:
ktransformers.operators.
balance_serve_
attention.flashinfer_attn
# optimized MLA implementation
kwargs
:
kwargs
:
generate_device
:
"
cuda"
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
prefill_device
:
"
cuda"
...
...
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml
View file @
3f9bbf11
...
@@ -50,7 +50,7 @@
...
@@ -50,7 +50,7 @@
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
replace
:
replace
:
class
:
ktransformers.operators.attention.flashinfer_attn
# optimized MLA implementation
class
:
ktransformers.operators.
balance_serve_
attention.flashinfer_attn
# optimized MLA implementation
kwargs
:
kwargs
:
generate_device
:
"
cuda"
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
prefill_device
:
"
cuda"
...
...
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
0 → 100644
View file @
3f9bbf11
-
match
:
class
:
ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace
:
class
:
ktransformers.operators.RoPE.RotaryEmbedding
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^lm_head$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
KLinearMarlin"
prefill_op
:
"
KLinearTorch"
# - match:
# name: "^model\\.layers\\..*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cuda"
# prefill_device: "cuda"
# generate_op: "VLinearMarlin"
# prefill_op: "KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
.(?!.*mlp
\\
.shared_expert_gate).*$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
VLinearMarlin"
prefill_op
:
"
KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp$"
class
:
ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
replace
:
class
:
ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2
# mlp module with custom forward function
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp
\\
.experts$"
replace
:
class
:
ktransformers.operators.experts.KTransformersExpertsV2
# custom MoE Kernel with expert paralleism
kwargs
:
prefill_device
:
"
cuda"
prefill_op
:
"
KExpertsTorch"
generate_device
:
"
cpu"
generate_op
:
"
KExpertsCPU"
out_device
:
"
cuda"
recursive
:
False
# don't recursively inject submodules of this module
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
replace
:
class
:
ktransformers.operators.balance_serve_attention.KQwen2MoeAttention
# optimized MLA implementation
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^model$"
replace
:
class
:
"
ktransformers.operators.models.KQwen2MoeModel"
kwargs
:
per_layer_prefill_intput_threshold
:
0
# 0 is close layer wise prefill
-
match
:
name
:
"
^model.embed_tokens"
replace
:
class
:
"
default"
kwargs
:
generate_device
:
"
cpu"
prefill_device
:
"
cpu"
-
match
:
class
:
ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm
replace
:
class
:
ktransformers.operators.layernorm.KQwen2MoeRMSNorm
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
class
:
ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP
replace
:
class
:
ktransformers.operators.mlp.KQwen2MoeMLP
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
\ No newline at end of file
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
0 → 100644
View file @
3f9bbf11
-
match
:
class
:
ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
replace
:
class
:
ktransformers.operators.RoPE.RotaryEmbedding
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^lm_head$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
VLinearMarlin"
prefill_op
:
"
KLinearTorch"
# - match:
# name: "^model\\.layers\\..*$" # regular expression
# class: torch.nn.Linear # only match modules matching name and class simultaneously
# replace:
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
# kwargs:
# generate_device: "cuda"
# prefill_device: "cuda"
# generate_op: "VLinearMarlin"
# prefill_op: "KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
.(?!.*mlp
\\
.shared_expert_gate).*$"
# regular expression
class
:
torch.nn.Linear
# only match modules matching name and class simultaneously
replace
:
class
:
ktransformers.operators.linear.KTransformersLinear
# optimized Kernel on quantized data types
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
generate_op
:
"
KLinearMarlin"
prefill_op
:
"
KLinearTorch"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp$"
class
:
ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
replace
:
class
:
ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2
# mlp module with custom forward function
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.mlp
\\
.experts$"
replace
:
class
:
ktransformers.operators.experts.KTransformersExpertsV2
# custom MoE Kernel with expert paralleism
kwargs
:
prefill_device
:
"
cuda"
prefill_op
:
"
KExpertsTorch"
generate_device
:
"
cpu"
generate_op
:
"
KExpertsCPU"
out_device
:
"
cuda"
recursive
:
False
# don't recursively inject submodules of this module
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
replace
:
class
:
ktransformers.operators.balance_serve_attention.KQwen3MoeAttention
# optimized MLA implementation
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
name
:
"
^model$"
replace
:
class
:
"
ktransformers.operators.models.KQwen2MoeModel"
kwargs
:
per_layer_prefill_intput_threshold
:
0
# 0 is close layer wise prefill
-
match
:
name
:
"
^model.embed_tokens"
replace
:
class
:
"
default"
kwargs
:
generate_device
:
"
cpu"
prefill_device
:
"
cpu"
-
match
:
class
:
ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm
replace
:
class
:
ktransformers.operators.layernorm.KQwen3MoeRMSNorm
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
-
match
:
class
:
ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP
replace
:
class
:
ktransformers.operators.mlp.KQwen2MoeMLP
kwargs
:
generate_device
:
"
cuda"
prefill_device
:
"
cuda"
\ No newline at end of file
ktransformers/server/args.py
View file @
3f9bbf11
...
@@ -20,6 +20,7 @@ class ArgumentParser:
...
@@ -20,6 +20,7 @@ class ArgumentParser:
parser
.
add_argument
(
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
self
.
cfg
.
model_device
,
help
=
"Warning: Abandoning this parameter"
"--device"
,
type
=
str
,
default
=
self
.
cfg
.
model_device
,
help
=
"Warning: Abandoning this parameter"
)
)
parser
.
add_argument
(
"--architectures"
,
type
=
str
,
default
=
self
.
cfg
.
model_name
)
parser
.
add_argument
(
"--gguf_path"
,
type
=
str
,
default
=
self
.
cfg
.
gguf_path
)
parser
.
add_argument
(
"--gguf_path"
,
type
=
str
,
default
=
self
.
cfg
.
gguf_path
)
parser
.
add_argument
(
"--optimize_config_path"
,
default
=
None
,
type
=
str
,
required
=
False
)
parser
.
add_argument
(
"--optimize_config_path"
,
default
=
None
,
type
=
str
,
required
=
False
)
parser
.
add_argument
(
"--cpu_infer"
,
type
=
int
,
default
=
self
.
cfg
.
cpu_infer
)
parser
.
add_argument
(
"--cpu_infer"
,
type
=
int
,
default
=
self
.
cfg
.
cpu_infer
)
...
@@ -93,6 +94,7 @@ class ArgumentParser:
...
@@ -93,6 +94,7 @@ class ArgumentParser:
parser
.
add_argument
(
"--user_algorithm"
,
type
=
str
,
default
=
self
.
cfg
.
user_algorithm
)
parser
.
add_argument
(
"--user_algorithm"
,
type
=
str
,
default
=
self
.
cfg
.
user_algorithm
)
parser
.
add_argument
(
"--force_think"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
user_force_think
)
parser
.
add_argument
(
"--force_think"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
user_force_think
)
parser
.
add_argument
(
"--use_cuda_graph"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
use_cuda_graph
)
parser
.
add_argument
(
"--use_cuda_graph"
,
action
=
argparse
.
BooleanOptionalAction
,
type
=
bool
,
default
=
self
.
cfg
.
use_cuda_graph
)
# parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=False)
# web config
# web config
parser
.
add_argument
(
"--web_cross_domain"
,
type
=
bool
,
default
=
self
.
cfg
.
web_cross_domain
)
parser
.
add_argument
(
"--web_cross_domain"
,
type
=
bool
,
default
=
self
.
cfg
.
web_cross_domain
)
...
@@ -137,7 +139,7 @@ class ArgumentParser:
...
@@ -137,7 +139,7 @@ class ArgumentParser:
self
.
cfg
.
server_port
=
args
.
port
self
.
cfg
.
server_port
=
args
.
port
self
.
cfg
.
user_force_think
=
args
.
force_think
self
.
cfg
.
user_force_think
=
args
.
force_think
args
.
gpu_memory_size
=
args
.
cache_lens
*
2
*
576
*
61
args
.
gpu_memory_size
=
4
*
1024
*
1024
*
1024
# TODO: set this to the actual GPU memory size
self
.
cfg
.
gpu_memory_size
=
args
.
gpu_memory_size
self
.
cfg
.
gpu_memory_size
=
args
.
gpu_memory_size
free_ports
=
get_free_ports
(
3
,
[
args
.
port
])
free_ports
=
get_free_ports
(
3
,
[
args
.
port
])
args
.
sched_port
=
free_ports
[
0
]
args
.
sched_port
=
free_ports
[
0
]
...
...
ktransformers/server/backend/interfaces/balance_serve.py
View file @
3f9bbf11
from
typing
import
Any
,
AsyncIterator
,
List
,
Optional
,
Set
from
typing
import
Any
,
AsyncIterator
,
List
,
Optional
,
Set
from
ktransformers.models.custom_cache
import
KDeepSeekV3Cache
from
ktransformers.models.custom_cache
import
KDeepSeekV3Cache
,
KGQACache
from
transformers
import
(
from
transformers
import
(
AutoTokenizer
,
AutoTokenizer
,
AutoConfig
,
AutoConfig
,
...
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
...
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
from
ktransformers.optimize.optimize
import
optimize_and_load_gguf
from
ktransformers.optimize.optimize
import
optimize_and_load_gguf
from
ktransformers.models.custom_modeling_deepseek_v3
import
KDeepseekV3ForCausalLM
from
ktransformers.models.custom_modeling_deepseek_v3
import
KDeepseekV3ForCausalLM
from
ktransformers.models.custom_modeling_deepseek_v2
import
KDeepseekV2ForCausalLM
from
ktransformers.models.custom_modeling_deepseek_v2
import
KDeepseekV2ForCausalLM
from
ktransformers.models.custom_modeling_qwen2_moe
import
KQwen2MoeForCausalLM
from
ktransformers.models.custom_modeling_qwen3_moe
import
KQwen3MoeForCausalLM
from
ktransformers.models.configuration_qwen3_moe
import
Qwen3MoeConfig
from
ktransformers.server.balance_serve.inference.model_runner
import
ModelRunner
from
ktransformers.server.balance_serve.inference.model_runner
import
ModelRunner
from
ktransformers.server.balance_serve.inference.sampling.sampler
import
Sampler
,
SamplingOptions
from
ktransformers.server.balance_serve.inference.sampling.sampler
import
Sampler
,
SamplingOptions
from
ktransformers.server.balance_serve.inference.query_manager
import
QueryManager
from
ktransformers.server.balance_serve.inference.query_manager
import
QueryManager
...
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
...
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
".."
,
".."
,
".."
,
"./optimize/optimize_rules/"
)
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
".."
,
".."
,
".."
,
"./optimize/optimize_rules/"
)
)
)
default_optimize_rules
=
{
default_optimize_rules
=
{
"DeepseekV3ForCausalLM"
:
ktransformer_rules_dir
+
"DeepSeek-V3-Chat-serve.yaml"
,
"DeepseekV3ForCausalLM"
:
ktransformer_rules_dir
+
"Moonlight-16B-A3B-serve.yaml"
,
"Qwen2MoeForCausalLM"
:
ktransformer_rules_dir
+
"Qwen2-57B-A14B-Instruct-serve.yaml"
,
# "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM"
:
ktransformer_rules_dir
+
"Qwen2-serve.yaml"
,
"Qwen3MoeForCausalLM"
:
ktransformer_rules_dir
+
"Qwen3Moe-serve.yaml"
,
}
}
...
@@ -105,7 +110,7 @@ class Engine:
...
@@ -105,7 +110,7 @@ class Engine:
model_runner
:
ModelRunner
model_runner
:
ModelRunner
sampler
:
Sampler
sampler
:
Sampler
query_manager
:
QueryManager
query_manager
:
QueryManager
cache
:
KDeepSeekV3Cache
cache
:
KDeepSeekV3Cache
|
KGQACache
def
__init__
(
self
,
args
:
ConfigArgs
=
default_args
,
generated_token_queue
:
Queue
=
None
,
broadcast_endpoint
:
str
=
None
,
kvcache_event
:
Event
=
None
):
def
__init__
(
self
,
args
:
ConfigArgs
=
default_args
,
generated_token_queue
:
Queue
=
None
,
broadcast_endpoint
:
str
=
None
,
kvcache_event
:
Event
=
None
):
self
.
args
=
args
self
.
args
=
args
...
@@ -117,17 +122,32 @@ class Engine:
...
@@ -117,17 +122,32 @@ class Engine:
self
.
device
=
self
.
args
.
device
self
.
device
=
self
.
args
.
device
self
.
sched_client
=
SchedulerClient
(
args
.
sched_port
)
self
.
sched_client
=
SchedulerClient
(
args
.
sched_port
)
self
.
updates
=
[]
self
.
updates
=
[]
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
True
)
self
.
cache
=
KDeepSeekV3Cache
(
config
,
self
.
args
.
page_size
)
try
:
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
True
)
except
:
if
args
.
model_name
==
"Qwen3Moe"
:
config
=
Qwen3MoeConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
True
)
else
:
assert
False
,
f
"model
{
args
.
model_name
}
not supported"
self
.
gen_queue
=
generated_token_queue
self
.
gen_queue
=
generated_token_queue
with
torch
.
device
(
"meta"
):
with
torch
.
device
(
"meta"
):
if
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
:
if
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
:
self
.
cache
=
KDeepSeekV3Cache
(
config
,
self
.
args
.
page_size
)
self
.
model
=
KDeepseekV3ForCausalLM
(
config
,
self
.
cache
)
self
.
model
=
KDeepseekV3ForCausalLM
(
config
,
self
.
cache
)
elif
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
:
elif
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
:
self
.
cache
=
KDeepSeekV3Cache
(
config
,
self
.
args
.
page_size
)
self
.
model
=
KDeepseekV2ForCausalLM
(
config
,
self
.
cache
)
self
.
model
=
KDeepseekV2ForCausalLM
(
config
,
self
.
cache
)
# print(self.block_num)
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
or
config
.
architectures
[
0
]
==
"Qwen3MoeForCausalLM"
:
self
.
cache
=
KGQACache
(
config
,
self
.
args
.
page_size
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
self
.
model
=
KQwen2MoeForCausalLM
(
config
,
self
.
cache
)
else
:
self
.
model
=
KQwen3MoeForCausalLM
(
config
,
self
.
cache
)
context
=
zmq
.
Context
()
context
=
zmq
.
Context
()
...
@@ -176,9 +196,12 @@ class Engine:
...
@@ -176,9 +196,12 @@ class Engine:
self
.
block_num
=
inference_context
.
k_cache
[
0
].
size
(
1
)
self
.
block_num
=
inference_context
.
k_cache
[
0
].
size
(
1
)
#@TODO add config
#@TODO add config
self
.
model
.
init_wrapper
(
self
.
args
.
use_cuda_graph
,
self
.
device
,
args
.
max_batch_size
,
self
.
block_num
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
or
config
.
architectures
[
0
]
==
"Qwen3MoeForCausalLM"
:
self
.
model
.
init_wrapper
(
self
.
args
.
use_cuda_graph
,
self
.
device
,
1024
,
args
.
max_batch_size
,
self
.
block_num
)
# TODO: 1024 is a magic number(max_batch_tokens)
else
:
self
.
model
.
init_wrapper
(
self
.
args
.
use_cuda_graph
,
self
.
device
,
args
.
max_batch_size
,
self
.
block_num
)
self
.
model_runner
=
ModelRunner
(
self
.
model
,
self
.
device
,
self
.
args
.
use_cuda_graph
,
page_size
=
args
.
page_size
)
self
.
model_runner
=
ModelRunner
(
self
.
model
,
self
.
device
,
self
.
args
.
use_cuda_graph
,
page_size
=
args
.
page_size
,
block_num
=
self
.
block_num
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
query_manager
=
QueryManager
(
device
=
self
.
device
,
page_size
=
args
.
page_size
)
self
.
query_manager
=
QueryManager
(
device
=
self
.
device
,
page_size
=
args
.
page_size
)
...
@@ -231,7 +254,7 @@ class Engine:
...
@@ -231,7 +254,7 @@ class Engine:
if
self
.
batch
is
not
None
:
if
self
.
batch
is
not
None
:
self
.
model_runner
.
sync
()
self
.
model_runner
.
sync
()
print
(
f
"Model execution time (GPU):
{
self
.
model_runner
.
model_time
:.
3
f
}
ms"
)
print
(
f
"Model execution time (GPU):
{
self
.
model_runner
.
model_time
:.
3
f
}
ms
,
{
1000
/
self
.
model_runner
.
model_time
:.
3
f
}
tokens/s
"
)
# if self.rank == 0:
# if self.rank == 0:
generated_tokens
,
probs
=
self
.
sampling
(
self
.
model_runner
.
output
)
generated_tokens
,
probs
=
self
.
sampling
(
self
.
model_runner
.
output
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment