Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wqshmzh
ktransformers
Commits
0da3792b
You need to sign in or sign up before continuing.
Commit
0da3792b
authored
Apr 28, 2025
by
djw
Browse files
support qwen3
parent
3f9bbf11
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
3 deletions
+9
-3
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
+1
-0
ktransformers/operators/balance_serve_attention.py
ktransformers/operators/balance_serve_attention.py
+5
-2
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
+1
-0
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
+1
-0
ktransformers/server/balance_serve/inference/model_runner.py
ktransformers/server/balance_serve/inference/model_runner.py
+1
-1
No files found.
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
View file @
0da3792b
...
@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
// extra global storage for barrier synchronization
)
{
)
{
int
prob_m
=
*
prob_m_ptr
;
int
prob_m
=
*
prob_m_ptr
;
prob_m
=
min
(
prob_m
,
1024
);
const
int
thread_m_blocks
=
min
(
div_ceil
(
prob_m
,
16
),
template_thread_m_blocks
);
const
int
thread_m_blocks
=
min
(
div_ceil
(
prob_m
,
16
),
template_thread_m_blocks
);
if
(
prob_m
>
16
*
thread_m_blocks
)
if
(
prob_m
>
16
*
thread_m_blocks
)
prob_m
=
(
16
*
thread_m_blocks
)
*
div_ceil
(
prob_m
,
(
16
*
thread_m_blocks
));
prob_m
=
(
16
*
thread_m_blocks
)
*
div_ceil
(
prob_m
,
(
16
*
thread_m_blocks
));
...
...
ktransformers/operators/balance_serve_attention.py
View file @
0da3792b
...
@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
...
@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
):
):
q_len
,
_
=
hidden_states
.
size
()
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_norm
(
self
.
q_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors
)
bsz_tensors_q
=
bsz_tensors
*
self
.
num_heads
key_states
=
self
.
k_norm
(
self
.
k_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors
)
bsz_tensors_kv
=
bsz_tensors
*
self
.
num_key_value_heads
query_states
=
self
.
q_norm
(
self
.
q_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors_q
)
key_states
=
self
.
k_norm
(
self
.
k_proj
(
hidden_states
,
bsz_tensors
),
bsz_tensors_kv
)
value_states
=
self
.
v_proj
(
hidden_states
,
bsz_tensors
)
value_states
=
self
.
v_proj
(
hidden_states
,
bsz_tensors
)
...
...
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
View file @
0da3792b
...
@@ -56,6 +56,7 @@
...
@@ -56,6 +56,7 @@
generate_device
:
"
cpu"
generate_device
:
"
cpu"
generate_op
:
"
KExpertsCPU"
generate_op
:
"
KExpertsCPU"
out_device
:
"
cuda"
out_device
:
"
cuda"
backend
:
"
AMXInt8"
# or "AMXBF16" or "llamafile" (default)
recursive
:
False
# don't recursively inject submodules of this module
recursive
:
False
# don't recursively inject submodules of this module
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
...
...
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
View file @
0da3792b
...
@@ -56,6 +56,7 @@
...
@@ -56,6 +56,7 @@
generate_device
:
"
cpu"
generate_device
:
"
cpu"
generate_op
:
"
KExpertsCPU"
generate_op
:
"
KExpertsCPU"
out_device
:
"
cuda"
out_device
:
"
cuda"
backend
:
"
AMXInt8"
# or "AMXBF16" or "llamafile" (default)
recursive
:
False
# don't recursively inject submodules of this module
recursive
:
False
# don't recursively inject submodules of this module
-
match
:
-
match
:
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
name
:
"
^model
\\
.layers
\\
..*
\\
.self_attn$"
...
...
ktransformers/server/balance_serve/inference/model_runner.py
View file @
0da3792b
...
@@ -85,7 +85,7 @@ class ModelRunner:
...
@@ -85,7 +85,7 @@ class ModelRunner:
elif
isinstance
(
self
.
model
,
KQwen2MoeForCausalLM
)
or
isinstance
(
self
.
model
,
KQwen3MoeForCausalLM
):
elif
isinstance
(
self
.
model
,
KQwen2MoeForCausalLM
)
or
isinstance
(
self
.
model
,
KQwen3MoeForCausalLM
):
self
.
model
.
flash_infer_attn_plan
(
batch
,
self
.
bsz_tensor_buf
,
self
.
num_tokens_tensor_buf
,
self
.
model
.
flash_infer_attn_plan
(
batch
,
self
.
bsz_tensor_buf
,
self
.
num_tokens_tensor_buf
,
num_q_heads
=
self
.
model
.
config
.
num_attention_heads
,
num_kv_heads
=
self
.
model
.
config
.
num_key_value_heads
,
num_q_heads
=
self
.
model
.
config
.
num_attention_heads
,
num_kv_heads
=
self
.
model
.
config
.
num_key_value_heads
,
head_dim
=
self
.
model
.
config
.
hidden_size
//
self
.
model
.
config
.
num_attention_heads
,
head_dim
=
128
,
page_size
=
self
.
model
.
cache
.
page_size
,
causal
=
True
,
page_size
=
self
.
model
.
cache
.
page_size
,
causal
=
True
,
q_data_type
=
torch
.
bfloat16
,
kv_data_type
=
torch
.
bfloat16
,
cuda_graph_idx
=
cuda_graph_idx
)
q_data_type
=
torch
.
bfloat16
,
kv_data_type
=
torch
.
bfloat16
,
cuda_graph_idx
=
cuda_graph_idx
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment