OpenDAS / ktransformers / Commits

Commit 0da3792b, authored Apr 28, 2025 by djw

support qwen3

Parent: 3f9bbf11
Showing 5 changed files with 9 additions and 3 deletions:
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu  (+1, -0)
ktransformers/operators/balance_serve_attention.py  (+5, -2)
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml  (+1, -0)
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml  (+1, -0)
ktransformers/server/balance_serve/inference/model_runner.py  (+1, -1)
csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
...
@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
        int* locks  // extra global storage for barrier synchronization
  ) {
    int prob_m = *prob_m_ptr;
    prob_m = min(prob_m, 1024);
    const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
    if (prob_m > 16 * thread_m_blocks)
      prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
...
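
To make the rounding arithmetic in this hunk easier to follow, here is a small Python sketch of the same logic. It is a reading of the lines above, not code from the repository; the default of 4 for template_thread_m_blocks is a hypothetical value, since the kernel receives it as a template parameter.

def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b

def plan_m_blocks(prob_m: int, template_thread_m_blocks: int = 4) -> tuple[int, int]:
    # Mirrors the CUDA hunk: cap the M dimension, pick how many 16-row
    # blocks each thread block covers, then round prob_m up to a multiple
    # of the per-iteration tile height (16 * thread_m_blocks).
    prob_m = min(prob_m, 1024)
    thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks)
    if prob_m > 16 * thread_m_blocks:
        prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks))
    return thread_m_blocks, prob_m

print(plan_m_blocks(100))  # (4, 128): 100 rows -> 4 m-blocks, padded to 128 rows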
ktransformers/operators/balance_serve_attention.py
...
@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
     ):
         q_len, _ = hidden_states.size()
-        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
-        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
+        bsz_tensors_q = bsz_tensors * self.num_heads
+        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads
+        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)
+        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)
         value_states = self.v_proj(hidden_states, bsz_tensors)
...
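
One reading of this change, offered with hedging: Qwen3 applies q_norm/k_norm per attention head over head_dim, so when the projected tensor is viewed as (tokens * heads, head_dim) the norm kernel sees bsz_tensors * num_heads (or * num_key_value_heads) rows, which is what bsz_tensors_q and bsz_tensors_kv compute above. A minimal PyTorch sketch of that per-head normalization, with illustrative sizes only:

import torch

tokens, num_heads, head_dim = 16, 8, 128           # illustrative sizes, not from the commit
q = torch.randn(tokens, num_heads * head_dim)      # stand-in for the q_proj output

q_per_head = q.view(tokens * num_heads, head_dim)  # bsz_tensors_q = tokens * num_heads rows
rms = torch.sqrt(q_per_head.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
q_normed = q_per_head / rms                        # RMSNorm over head_dim, applied per head

print(q_per_head.shape)                            # torch.Size([128, 128])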
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
...
@@ -56,6 +56,7 @@
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
...
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
...
@@ -56,6 +56,7 @@
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
...
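
Both Qwen2-serve.yaml and Qwen3Moe-serve.yaml key the next injection rule on the same self_attn name pattern (the YAML escapes the dots as "\\."). A quick standalone check of that regex; the module names below are hypothetical examples, not taken from the commit:

import re

pattern = re.compile(r"^model\.layers\..*\.self_attn$")

print(bool(pattern.match("model.layers.0.self_attn")))         # True: the attention module itself
print(bool(pattern.match("model.layers.0.self_attn.q_proj")))  # False: submodules are not matched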
ktransformers/server/balance_serve/inference/model_runner.py
...
@@ -85,7 +85,7 @@ class ModelRunner:
         elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                              num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
-                                             head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads, page_size=self.model.cache.page_size,
+                                             head_dim=128, page_size=self.model.cache.page_size,
                                              causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
         else:
...
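
The one-line change swaps a derived head_dim for an explicit 128 in the flash_infer_attn_plan call. A plausible motivation, hedged: Qwen3 MoE configurations can set a head_dim that does not equal hidden_size // num_attention_heads, so deriving it would hand the attention planner the wrong value. A tiny sketch with purely illustrative numbers:

# Illustrative numbers only, not taken from the commit or any specific config.
hidden_size = 2048
num_attention_heads = 32
explicit_head_dim = 128

derived = hidden_size // num_attention_heads  # 64: what the old line computed
print(derived, explicit_head_dim)             # 64 128 -> the derived value would disagree

A more general variant might read the value from the config when present, e.g. getattr(self.model.config, "head_dim", hidden_size // num_attention_heads), but this commit simply passes 128.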