OpenDAS / ktransformers · Commits

Commit 74bb7fdc
Merge remote-tracking branch 'dev/support-amx-2'
Authored Apr 28, 2025 by qiyuxinlin
Parents: ba92cf1a, be4b27e8

Changes: 53 · Showing 13 changed files with 262 additions and 196 deletions (+262, -196)
ktransformers/server/args.py                                      +2    -1
ktransformers/server/backend/interfaces/balance_serve.py          +32   -9
ktransformers/server/balance_serve/inference/forward_batch.py     +1    -1
ktransformers/server/balance_serve/inference/model_runner.py      +96   -175
ktransformers/server/balance_serve/sched_rpc.py                   +7    -2
ktransformers/server/balance_serve/settings.py                    +106  -0
ktransformers/server/config/config.py                             +1    -0
ktransformers/server/requirements.txt                             +1    -1
ktransformers/tests/test_speed.py                                 +1    -1
ktransformers/util/custom_gguf.py                                 +12   -3
pyproject.toml                                                    +1    -1
requirements-local_chat.txt                                       +1    -1
third_party/custom_flashinfer                                     +1    -1
ktransformers/server/args.py

@@ -20,6 +20,7 @@ class ArgumentParser:
         parser.add_argument("--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter")
+        parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
         parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
         parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
...
@@ -137,7 +138,7 @@ class ArgumentParser:
         self.cfg.server_port = args.port
         self.cfg.user_force_think = args.force_think
-        args.gpu_memory_size = args.cache_lens * 2 * 576 * 61
+        args.gpu_memory_size = 4 * 1024 * 1024 * 1024  # TODO: set this to the actual GPU memory size
         self.cfg.gpu_memory_size = args.gpu_memory_size
         free_ports = get_free_ports(3, [args.port])
         args.sched_port = free_ports[0]
...
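Note: the new default hard-codes a 4 GiB budget and leaves the real value as a TODO. Below is a minimal sketch of how that TODO could be resolved by querying the device instead of using a constant; the helper name, the 0.9 safety fraction, and the CPU-only fallback are illustrative assumptions, not part of this commit.

import torch

def detect_gpu_memory_size(device_index: int = 0, fraction: float = 0.9) -> int:
    """Return a usable GPU memory budget in bytes.

    Falls back to the commit's 4 GiB constant when CUDA is unavailable,
    so the scheduler still receives a sane value on CPU-only hosts.
    """
    if torch.cuda.is_available():
        total = torch.cuda.get_device_properties(device_index).total_memory
        return int(total * fraction)
    return 4 * 1024 * 1024 * 1024

# Hypothetical replacement for the TODO above:
# args.gpu_memory_size = detect_gpu_memory_size()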
ktransformers/server/backend/interfaces/balance_serve.py

 from typing import Any, AsyncIterator, List, Optional, Set
-from ktransformers.models.custom_cache import KDeepSeekV3Cache
+from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
 from transformers import (
     AutoTokenizer,
     AutoConfig,
...
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
 from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
+from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
+from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
 from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
 from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
...
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
 )
 default_optimize_rules = {
+    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
     "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
-    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
+    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
+    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
 }
...
@@ -105,7 +110,7 @@ class Engine:
     model_runner: ModelRunner
     sampler: Sampler
     query_manager: QueryManager
-    cache: KDeepSeekV3Cache
+    cache: KDeepSeekV3Cache | KGQACache
     def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
         self.args = args
...
@@ -117,17 +122,32 @@ class Engine:
         self.device = self.args.device
         self.sched_client = SchedulerClient(args.sched_port)
         self.updates = []
-        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
-        self.cache = KDeepSeekV3Cache(config, self.args.page_size)
+        try:
+            config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+        except:
+            if args.model_name == "Qwen3Moe":
+                config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+            else:
+                assert False, f"model {args.model_name} not supported"
         self.gen_queue = generated_token_queue

         with torch.device("meta"):
             if config.architectures[0] == "DeepseekV3ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV3ForCausalLM(config, self.cache)
             elif config.architectures[0] == "DeepseekV2ForCausalLM":
+                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                 self.model = KDeepseekV2ForCausalLM(config, self.cache)
                 # print(self.block_num)
+            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+                self.cache = KGQACache(config, self.args.page_size)
+                if config.architectures[0] == "Qwen2MoeForCausalLM":
+                    self.model = KQwen2MoeForCausalLM(config, self.cache)
+                else:
+                    self.model = KQwen3MoeForCausalLM(config, self.cache)

         context = zmq.Context()
...
@@ -176,9 +196,12 @@ class Engine:
         self.block_num = inference_context.k_cache[0].size(1)

         #@TODO add config
-        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
+        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024, args.max_batch_size, self.block_num)  # TODO: 1024 is a magic number(max_batch_tokens)
+        else:
+            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

-        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size=args.page_size)
+        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size=args.page_size, block_num=self.block_num)
         self.sampler = Sampler()
         self.query_manager = QueryManager(device=self.device, page_size=args.page_size)
...
@@ -231,7 +254,7 @@ class Engine:
             if self.batch is not None:
                 self.model_runner.sync()
-                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
+                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000 / self.model_runner.model_time:.3f} tokens/s")
                 # if self.rank == 0:
                 generated_tokens, probs = self.sampling(self.model_runner.output)
...
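Note: Engine.__init__ now selects both the KV-cache layout and the model wrapper from config.architectures[0]: the DeepSeek models keep the MLA-style KDeepSeekV3Cache, while the Qwen2/Qwen3 MoE models get the new KGQACache. A small self-contained sketch of that dispatch pattern follows; the placeholder classes merely stand in for the ktransformers wrappers so the example runs anywhere, and are not part of the commit.

from dataclasses import dataclass

# Placeholder stand-ins for the ktransformers cache classes.
@dataclass
class MLACache:      # plays the role of KDeepSeekV3Cache
    page_size: int

@dataclass
class GQACache:      # plays the role of KGQACache
    page_size: int

# architecture name -> (cache factory, model wrapper label)
_DISPATCH = {
    "DeepseekV3ForCausalLM": (MLACache, "KDeepseekV3ForCausalLM"),
    "DeepseekV2ForCausalLM": (MLACache, "KDeepseekV2ForCausalLM"),
    "Qwen2MoeForCausalLM":   (GQACache, "KQwen2MoeForCausalLM"),
    "Qwen3MoeForCausalLM":   (GQACache, "KQwen3MoeForCausalLM"),
}

def build_cache_and_model(architecture: str, page_size: int):
    """Mirror of the if/elif chain in Engine.__init__: MLA-style cache for the
    DeepSeek models, GQA cache for the Qwen MoE models."""
    try:
        cache_cls, model_name = _DISPATCH[architecture]
    except KeyError:
        raise AssertionError(f"model {architecture} not supported")
    return cache_cls(page_size), model_name

print(build_cache_and_model("Qwen3MoeForCausalLM", page_size=256))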
ktransformers/server/balance_serve/inference/forward_batch.py

@@ -281,4 +281,4 @@ class ForwardBatchOutput:
         self.generated_tokens_num = []
         self.top_ps = []
         self.temperatures = []
-        pass
+        self.num_batchs = 1
\ No newline at end of file
ktransformers/server/balance_serve/inference/model_runner.py

@@ -27,6 +27,8 @@ from ktransformers.server.balance_serve.inference.forward_batch import ForwardBa...
 from ktransformers.server.config.config import Config
 from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
 from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
+from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
+from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
 from ktransformers.server.balance_serve.inference.query_manager import QueryManager
 from ktransformers.server.balance_serve.settings import sched_ext
...
@@ -40,11 +42,11 @@ def deduplicate_and_sort(lst):
 class ModelRunner:
     """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

-    model: KDeepseekV3ForCausalLM
+    model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
     input: ForwardBatchInput | list[ForwardBatchInput]
     output: ForwardBatchOutput

-    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
+    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):
         self.stream = torch.cuda.Stream(device=device)
         # commented out for now
...
@@ -58,120 +60,92 @@ class ModelRunner:
         self.use_cuda_graph = use_cuda_graph
         self.model_time = 0
         self.page_size = page_size
+        self.block_num = block_num
         # GPU timing for model execution
         self.start_model_event = torch.cuda.Event(enable_timing=True)
         self.end_model_event = torch.cuda.Event(enable_timing=True)
-        if isinstance(self.cuda_graphs, list):
-            self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
-            self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device=self.device) for i in range(len(self.cuda_graphs))]
-            self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device=self.device) for i in range(len(self.cuda_graphs))]
-        else:
-            self.graphs = torch.cuda.CUDAGraph()
-            self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device=self.device)
-            self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device=self.device)
+        self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
+        self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device=self.device) for i in range(len(self.cuda_graphs))]
+        self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device=self.device) for i in range(len(self.cuda_graphs))]
         self.num_mini_batches = num_mini_batches
         self.max_chunk_size = max_chunk_size
         self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
         self.num_tokens_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)

+    def model_attn_plan(self, batch, cuda_graph_idx=0):
+        if isinstance(self.model, KDeepseekV3ForCausalLM):
+            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
+                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
+                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
+                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+        elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
+            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
+                                             num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
+                                             head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_num') else self.model.config.hidden_size // self.model.config.num_attention_heads,
+                                             page_size=self.model.cache.page_size, causal=True,
+                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
+        else:
+            assert False, "model type not supported"

     def warmup(self):
-        def capture_graphs(cuda_graph_idx=-1):
-            if cuda_graph_idx != -1:
-                with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
-                    self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
-                self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
-            else:
-                with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
-                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
-                self.graph_memory_pool = self.graphs.pool()
-        if isinstance(self.cuda_graphs, list):
-            self.input = []
-            self.features_buf = []
-            self.outputs_buf = []
+        def capture_graphs(cuda_graph_idx):
+            with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
+                self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
+            self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
+        self.input = []
+        self.features_buf = []
+        self.outputs_buf = []
         self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
         self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
         for i in range(len(self.cuda_graphs)):
             prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch
             self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches=self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i]))
             self.features_buf.append(self.model.batch_embeddings(self.input[i]))
             batch_size = self.input[i].minibatch.q_indptr.size(0) - 1
             num_tokens = self.features_buf[i][0].size(0)
             print("capturing cuda graph", batch_size, num_tokens)
+            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
+                self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens, batch_size, self.block_num, i)  # TODO: 1024 is a magic number(max_batch_tokens)
             self.bsz_tensor_buf[0] = batch_size
             self.num_tokens_tensor_buf[0] = num_tokens
-            self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
-                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+            self.model_attn_plan(self.input[i], i)
             page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
             self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
             self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
             self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
             self.outputs_buf.append(None)
             torch.cuda.synchronize()
             for warm_up_iters in range(11):
                 with torch.cuda.stream(self.stream):
-                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
+                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i)
             torch.cuda.synchronize()
+            self.outputs_buf[i].num_batchs = batch_size
             capture_graphs(i)
             with torch.cuda.stream(self.stream):
                 self.graphs[i].replay()
             self.sync(calc_time=False)
             print(f"cuda_graph: {i + 1}/{len(self.cuda_graphs)}, warmup finished.")
-        else:
-            self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches=self.num_mini_batches)
-            self.features_buf = self.model.batch_embeddings(self.input)
-            batch_size = self.input.minibatch.q_indptr.size(0) - 1
-            num_tokens = self.features_buf[0].size(0)
-            self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
-            self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
-            self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
-                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
-            page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
-            self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
-            self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
-            self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
-            torch.cuda.synchronize()
-            for warm_up_iters in range(11):
-                with torch.cuda.stream(self.stream):
-                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
-            torch.cuda.synchronize()
-            def capture_graphs():
-                with torch.cuda.graph(self.graphs, stream=self.stream):
-                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
-                # self.graph_memory_pool = self.graphs.pool()
-            capture_graphs()
-            with torch.cuda.stream(self.stream):
-                self.graphs.replay()
-            self.sync(calc_time=False)
-            print("warmup finished.")

     def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
         with torch.cuda.stream(self.stream):
...
@@ -189,107 +163,54 @@ class ModelRunner:
-            if isinstance(self.cuda_graphs, list):
             # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens
             cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
-                if cuda_graph_idx == len(self.cuda_graphs):
-                    assert False, "num_tokens is too large"
-            else:
-                cuda_graph_idx = -1
+            if not self.use_cuda_graph:
+                cuda_graph_idx = 0
+            # if cuda_graph_idx == len(self.cuda_graphs):
+            #     assert False, "num_tokens is too large"
             if self.use_cuda_graph:
-                if cuda_graph_idx != -1:
-                    self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
-                else:
-                    self.input.fill(batch, query_manager, self.page_size)
+                self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
             else:
-                self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
+                self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)]

-            if cuda_graph_idx != -1 and self.use_cuda_graph:
+            if self.use_cuda_graph:
                 self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
             else:
-                self.features = self.model.batch_embeddings(self.input, device=self.device)
+                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)

             self.bsz_tensor_buf.copy_(batch_size)
             self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))

             if self.use_cuda_graph:
-                if cuda_graph_idx != -1:
-                    self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
-                else:
-                    self.features_buf[0].copy_(self.features[0], non_blocking=True)
-                self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                                 sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
-                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
+                self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx)
+                self.start_model_event.record(self.stream)
                 page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
                 self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
                 self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
                 self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
                 self.replay(cuda_graph_idx)
                 self.output = ForwardBatchOutput()
                 self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                 self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
                 self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
             else:
-                self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                                 sm_scale=self.model.model.layers[0].self_attn.softmax_scale,
-                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
-                page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
-                self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
-                self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
-                self.output.top_ps.append(self.input.minibatch.top_ps)
-                self.output.temperatures.append(self.input.minibatch.temperatures)
+                self.start_model_event.record(self.stream)
+                self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
+                self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
+                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
+                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
             self.end_model_event.record(self.stream)
...
     def replay(self, cuda_graph_idx=-1):
...
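Note: the rewritten run() always works with a list of captured CUDA graphs and picks the bucket whose capacity covers the incoming token count (the next(...) expression above). Below is a standalone sketch of that selection rule with illustrative bucket sizes; it assumes the capacities are sorted ascending, as self.cuda_graphs is after deduplicate_and_sort, and is not part of the commit.

from bisect import bisect_left

CUDA_GRAPH_BUCKETS = [64, 256, 1024, 4096]  # illustrative token capacities, ascending

def pick_graph_index(num_tokens: int, buckets=CUDA_GRAPH_BUCKETS) -> int:
    """Smallest captured graph whose capacity covers num_tokens.

    Equivalent to the expression used in ModelRunner.run():
    next((i for i, cap in enumerate(buckets) if cap >= num_tokens), len(buckets)).
    """
    idx = bisect_left(buckets, num_tokens)
    if idx == len(buckets):
        raise ValueError(f"num_tokens={num_tokens} exceeds the largest captured graph ({buckets[-1]})")
    return idx

assert pick_graph_index(1) == 0
assert pick_graph_index(65) == 1
assert pick_graph_index(1024) == 2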
ktransformers/server/balance_serve/sched_rpc.py

@@ -10,7 +10,7 @@ current_file_path = os.path.abspath(__file__)
 # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 import pickle
 import argparse
-from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
+from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe
...
@@ -209,5 +209,10 @@ if __name__ == '__main__':
     args = parser.parse_args()
     with open(args.config, "rb") as f:
         main_args = pickle.load(f)
-    settings = create_sched_settings(main_args)
+    if main_args.architectures == "Qwen2MoeForCausalLM":
+        settings = create_sched_settings_qwen2moe(main_args)
+    elif main_args.architectures == "Qwen3MoeForCausalLM":
+        settings = create_sched_settings_qwen3moe(main_args)
+    else:
+        settings = create_sched_settings(main_args)
     start_server(settings, main_args)
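Note: the __main__ block now branches on main_args.architectures to choose a scheduler-settings factory. A table-driven equivalent is sketched below as a possible alternative, assuming the three factories keep the signatures this commit gives them; the dict and helper are hypothetical, not part of the change.

def dispatch_sched_settings(main_args, factories):
    """Pick the scheduler-settings factory for main_args.architectures,
    falling back to the generic factory, mirroring the else branch above."""
    factory = factories.get(main_args.architectures, factories["default"])
    return factory(main_args)

# Usage inside sched_rpc.py's __main__ block (names from this commit):
# factories = {
#     "Qwen2MoeForCausalLM": create_sched_settings_qwen2moe,
#     "Qwen3MoeForCausalLM": create_sched_settings_qwen3moe,
#     "default": create_sched_settings,
# }
# settings = dispatch_sched_settings(main_args, factories)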
ktransformers/server/balance_serve/settings.py

@@ -11,6 +11,8 @@ from time import sleep
 import sched_ext
 from transformers import AutoConfig
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+

 def create_sched_settings(args):
     default_sample_options = sched_ext.SampleOptions()
     model_name = os.path.basename(os.path.normpath(args.model_dir))
...
@@ -64,7 +66,111 @@ def create_sched_settings(args):
     return settings

+def create_sched_settings_qwen2moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads  # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+    max_decode_batch_size = max_batch_size - 2
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+    settings.strategy_name = args.sched_strategy
+    settings.auto_derive()
+    return settings
+
+def create_sched_settings_qwen3moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads  # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+    max_decode_batch_size = max_batch_size - 2
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+    settings.strategy_name = args.sched_strategy
+    settings.auto_derive()
+    return settings
...
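Note: create_sched_settings_qwen2moe and create_sched_settings_qwen3moe differ only in the config class used to read the model directory. A possible consolidation is sketched below; it returns a plain dict rather than sched_ext.ModelSettings so the sketch runs without the compiled scheduler extension, and the helper name is hypothetical, not part of the commit.

from transformers import AutoConfig

def gqa_model_settings(args, config_loader=AutoConfig.from_pretrained) -> dict:
    """Shared piece of the two new factories: everything except the config class.

    For Qwen3 MoE one would pass config_loader=Qwen3MoeConfig.from_pretrained
    and then fill the remaining sched_ext.Settings fields exactly as above.
    """
    model_config = config_loader(args.model_dir, trust_remote_code=True)
    return {
        "model_path": args.model_dir,
        "layer_count": model_config.num_hidden_layers,
        "num_k_heads": model_config.num_key_value_heads,
        "k_head_dim": 128,
        "bytes_per_params": 2,
        "bytes_per_kv_cache_element": 2,
    }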
ktransformers/server/config/config.py

@@ -100,6 +100,7 @@ class Config(metaclass=Singleton):
         # to make sure it consistent with previous version
         self.model_path: str = self.model_dir
         self.model_name: str = self.model.get("name", "")
+        self.architectures: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
         self.use_cuda_graph = self.model.get("use_cuda_graph", True)
...
ktransformers/server/requirements.txt
View file @
74bb7fdc
torch >= 2.3.0
torch >= 2.3.0
transformers == 4.
43.2
transformers == 4.
51.3
fastapi >= 0.111.0
fastapi >= 0.111.0
langchain >= 0.2.0
langchain >= 0.2.0
blessed >= 1.20.0
blessed >= 1.20.0
...
...
ktransformers/tests/test_speed.py

@@ -146,7 +146,7 @@ async def main(concurrent_requests , prompt, max_tokens, model):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
-    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
...
ktransformers/util/custom_gguf.py

@@ -912,6 +912,9 @@ def translate_name_to_gguf(name):
     name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
     name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
     name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
+    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
+    name = name.replace(".self_attn.k_norm", ".attn_k_norm")
+
     name = name.replace(".shared_expert.", ".shared_experts.")
     name = name.replace(".shared_expert_", ".shared_experts_")
...
@@ -922,17 +925,23 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
     name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
     name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
     name = name.replace(".mlp.experts", "")
+    name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
+    name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
+    name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
     name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
     name = name.replace(".block_sparse_moe.experts", "")
+    name = name.replace(".feed_forward.experts", "")
+    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
+    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
+    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
+    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")

     return name

 if __name__ == '__main__':
     gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
     loader = GGUFLoader(gguf_path)
...
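Note: the added rules map Qwen3's per-head attention norms and the pre-merged expert tensors onto GGUF names. A short illustration of the string rewriting follows, using a local copy of just the rules added in this hunk; the real translate_name_to_gguf applies many more replacements, in a fixed order.

def _apply_new_rules(name: str) -> str:
    """Only the replacement rules added in this commit, in the order listed here."""
    for old, new in [
        (".self_attn.q_norm", ".attn_q_norm"),
        (".self_attn.k_norm", ".attn_k_norm"),
        (".mlp.experts.ffn_down_exps", ".ffn_down_exps"),
        (".mlp.experts.ffn_gate_exps", ".ffn_gate_exps"),
        (".mlp.experts.ffn_up_exps", ".ffn_up_exps"),
        (".feed_forward.router", ".ffn_gate_inp"),
    ]:
        name = name.replace(old, new)
    return name

print(_apply_new_rules("model.layers.0.self_attn.q_norm.weight"))
# -> model.layers.0.attn_q_norm.weight
print(_apply_new_rules("model.layers.0.mlp.experts.ffn_up_exps.weight"))
# -> model.layers.0.ffn_up_exps.weight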
pyproject.toml

@@ -16,7 +16,7 @@ dynamic = ["version"]
 dependencies = [
     "torch >= 2.3.0",
-    "transformers == 4.43.2",
+    "transformers == 4.51.3",
     "fastapi >= 0.111.0",
     "uvicorn >= 0.30.1",
     "langchain >= 0.2.0",
...
requirements-local_chat.txt
View file @
74bb7fdc
fire
fire
transformers==4.
43.2
transformers==4.
51.3
numpy
numpy
torch>=2.3.0
torch>=2.3.0
packaging
packaging
...
...
third_party/custom_flashinfer @ af4259e8
Compare fd94393f ... af4259e8

-Subproject commit fd94393fb5b8ba8bae9c0bd6ab1c2a429d81ac76
+Subproject commit af4259e8a33f095b419d1fd1733a50b22fc84c49