OpenDAS / ktransformers · Commits

Commit 25cee581, authored Mar 31, 2025 by Atream

    add balance-serve, support concurrence

Parent: 8d0292aa
Changes: 196 · Showing 20 changed files with 1495 additions and 88 deletions (+1495 -88)
Changed files (additions / deletions):
  ktransformers/operators/RoPE.py (+53 -0)
  ktransformers/operators/attention.py (+91 -1)
  ktransformers/operators/experts.py (+250 -34)
  ktransformers/operators/flashinfer_wrapper.py (+106 -11)
  ktransformers/operators/gate.py (+1 -1)
  ktransformers/operators/layernorm.py (+78 -0)
  ktransformers/operators/linear.py (+161 -11)
  ktransformers/operators/mlp.py (+23 -0)
  ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (+1 -1)
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml (+90 -0)
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (+4 -4)
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (+3 -3)
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml (+92 -0)
  ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (+1 -1)
  ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml (+94 -0)
  ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (+1 -1)
  ktransformers/server/args.py (+27 -9)
  ktransformers/server/backend/args.py (+1 -10)
  ktransformers/server/backend/context_manager.py (+12 -1)
  ktransformers/server/backend/interfaces/balance_serve.py (+406 -0)
ktransformers/operators/RoPE.py (view file @ 25cee581)

@@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(

            self.orig_module.rope_type,
            self.orig_module.config,
        )


    class RotaryEmbeddingV4(BaseInjectedModule):
        def __init__(
            self,
            key: str,
            gguf_loader: GGUFLoader,
            config: PretrainedConfig,
            orig_module: nn.Module,
            # device: str = "cuda",
            generate_device: str = "cuda",
            prefill_device: str = "cuda",
            **kwargs,
        ):
            BaseInjectedModule.__init__(
                self, key, gguf_loader, config, orig_module, generate_device, **kwargs
            )
            self.generate_device = generate_device
            self.prefill_device = prefill_device

        @torch.no_grad()
        def forward(self, x, position_ids):
            # x: [bs, num_attention_heads, seq_len, head_size]
            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
            position_ids_expanded = position_ids[:, None, :].float()
            # Force float32 since bfloat16 loses precision on long contexts
            # See https://github.com/huggingface/transformers/pull/29285
            device_type = x.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos()
                sin = emb.sin()
            return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

        def load(self):
            self._init(
                dim=self.config.qk_rope_head_dim,
                max_position_embeddings=self.config.max_position_embeddings,
                base=self.config.rope_theta,
                device=self.device,
            )

        def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
            self.scaling_factor = scaling_factor
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
            )
            # self.register_buffer("inv_freq", inv_freq, persistent=False)
            # For BC we register cos and sin cached
            self.max_seq_len_cached = max_position_embeddings
\ No newline at end of file
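The cos/sin pair returned by RotaryEmbeddingV4.forward is consumed by the attention code below via apply_rotary_pos_emb. As a reference, here is a minimal sketch of the standard rotate-half application in the Hugging Face LLaMA convention (which this repo says it copies); function names and the unsqueeze_dim handling are illustrative and may differ slightly from the ktransformers helpers.

    # Minimal sketch of standard RoPE application (assumed convention, not the
    # exact ktransformers helper).
    import torch

    def rotate_half(x):
        # Split the last dimension in two halves and swap them with a sign flip.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
        # Broadcast cos/sin over the head dimension and rotate q and k.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin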
ktransformers/operators/attention.py (view file @ 25cee581)

@@ -32,7 +32,8 @@ import os

    from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
    if flashinfer_enabled:
        from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
        from flashinfer.mla import BatchMLAPagedAttentionWrapper
    from ktransformers.models.custom_cache import KDeepSeekV3Cache
    logger = logging.getLogger("attention")

    # Copied from transformers.models.llama.modeling_llama.rotate_half

@@ -759,3 +760,92 @@ class KLlamaAttention(BaseInjectedModule):

            attn_weights = None

        return attn_output, attn_weights, past_key_value


    class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
        def __init__(self,
                     key: str,
                     gguf_loader: GGUFLoader,
                     config: PretrainedConfig,
                     orig_module: nn.Module,
                     prefill_device: str = "cuda",
                     generate_device: str = "cuda",
                     chunck_size: int = 1000,
                     **kwargs):
            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
            self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
            self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

        def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
            if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
                kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
                q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
                out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
                self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                          bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
                self.q_absorb.weight.data = q_absorb
                self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                            bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
                self.out_absorb.weight.data = out_absorb
                #del self.orig_module.kv_b_proj
            q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
            out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
            return q_absorb, out_absorb

        def forward(self,
                    hidden_states: torch.Tensor,
                    kv_cache: KDeepSeekV3Cache,
                    position_ids: torch.Tensor,
                    wrapper: BatchMLAPagedAttentionWrapper,
                    num_tokens_tensors: torch.Tensor,
                    page_idx: torch.Tensor,
                    page_offset: torch.Tensor,
                    ):
            q_len, _ = hidden_states.size()

            if self.q_lora_rank is None:
                q = self.q_proj(hidden_states, num_tokens_tensors)
            else:
                q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)

            q = q.view(q_len, self.num_heads, self.q_head_dim)
            q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
            compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            compressed_kv = compressed_kv.contiguous()
            compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
            k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
            compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

            cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
            q_pe = q_pe.squeeze(0)

            if kv_cache is not None:
                # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
                cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models

                compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

            q_absorb, out_absorb = self.get_absorbed()

            q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            q_nope = q_nope.transpose(0, 1)

            # q_nope.squeeze_(1)
            # q_pe.squeeze_(1)

            attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)

            attn_output = attn_output.transpose(0, 1)
            attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(0, 1)

            attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output, num_tokens_tensors)

            return attn_output
\ No newline at end of file
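The get_absorbed trick above folds the "nope" half of kv_b_proj into the query (q_absorb) and defers the value projection (out_absorb), so attention runs entirely in the compressed kv_lora_rank space. A shape-only sketch of that dataflow follows; the dimensions are illustrative, and softmax scaling plus the RoPE component are omitted for brevity.

    # Shape sketch of MLA matrix absorption (illustrative sizes, not a real
    # DeepSeek config).
    import torch

    num_heads, qk_nope, v_head_dim, kv_lora_rank, q_len, kv_len = 4, 16, 16, 32, 1, 8

    q_nope = torch.randn(q_len, num_heads, qk_nope)
    q_absorb = torch.randn(num_heads, qk_nope, kv_lora_rank)
    out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)
    compressed_kv = torch.randn(kv_len, kv_lora_rank)

    # Fold the nope projection into the query: [heads, q_len, kv_lora_rank]
    q_lat = torch.matmul(q_nope.transpose(0, 1), q_absorb)
    # Attention scores and context live entirely in the latent space.
    scores = torch.softmax(q_lat @ compressed_kv.T, dim=-1)   # [heads, q_len, kv_len]
    ctx = scores @ compressed_kv                               # [heads, q_len, kv_lora_rank]
    # Project the latent context back to the per-head value dimension.
    out = torch.matmul(ctx, out_absorb.mT)                     # [heads, q_len, v_head_dim]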
ktransformers/operators/experts.py (view file @ 25cee581)

@@ -37,6 +37,10 @@ import time

    from ktransformers.operators.cpuinfer import CPUInfer


    def deduplicate_and_sort(lst):
        return sorted(set(lst))
    #cuda_graphs = [Config().chunk_size]
    cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
    # class Base(BaseInjectedModule, ABC):
    class KExpertsBase(ABC):
        def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):

@@ -112,6 +116,7 @@ class KExpertsBase(ABC):

                tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
            return tensors


    class KExpertsCPU(KExpertsBase):
        input_tensor_cpu: Tensor = None
        expert_ids_cpu: Tensor = None

@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):

        output_cpu: Tensor = None
        output_gpu_map: dict = {}  # Manage output tensor buffer on different gpu
        #stream_map:dict = {} # Manage cuda stream on different gpu
        #gguf_loader:GGUFLoader = None
        CPU_INFER = None
        # @TODO add yaml
        CPU_INFER = CPUInfer(Config().cpu_infer)
        def __init__(
            self,
            key: str,

@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):

            **kwargs
        ):
            super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
            if KExpertsCPU.CPU_INFER is None:
                KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
            #if KExpertsCPU.gguf_loader is None:
            #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
            self.gguf_loader = gguf_loader
            assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
            self.n_routed_experts = n_routed_experts
            self.out_device = out_device

@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):

            down_ptr = ctypes.addressof(
                ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
            )
            #print(self.gate_type, self.up_type, self.down_type)
            #print(self.gate_qtype, self.up_qtype, self.down_qtype)
            n_routed_experts = self.n_routed_experts
            # n_routed_experts = len(self.orig_module)
            moe_config = MOEConfig(

@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):

            self.cpu_infer.submit(self.moe.warm_up())
            self.cpu_infer.sync()
            if self.out_device not in KExpertsCPU.output_gpu_map:
                KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
                if isinstance(cuda_graphs, list):
                    KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
                else:
                    KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
            if KExpertsCPU.input_tensor_cpu == None:
                KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
                KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
                KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
                KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
                if isinstance(cuda_graphs, list):
                    KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
                    KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
                    KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
                    KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
                    KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
                else:
                    KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
                    KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
                    KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
                    KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
                    KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)

        def submit_for_one_decode(self, input_tensor, expert_ids, weights):
            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))

        def sync_for_one_decode(self):
            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
            return KExpertsCPU.output_gpu_map[self.out_device]

        def forward(self, input_tensor, expert_ids, weights):
            # generate, capture and run cuda graph
            # print(expert_ids)
            if input_tensor.size(0) == 1 and torch.cuda.is_current_stream_capturing():
                # TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
                #print("capturing experts")

        def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
            if bsz_tensor is None:
                bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
            if cuda_graph_idx != -1:
                KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
                KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
            else:
                KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
                KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
                KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))

        def sync_for_one_decode(self, cuda_graph_idx=0):
            if cuda_graph_idx != -1:
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
            else:
                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
                KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
                return KExpertsCPU.output_gpu_map[self.out_device]

        def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
            # generate, capture and run cuda graph
            # print(expert_ids)
            if bsz_tensor is None:
                bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
            if torch.cuda.is_current_stream_capturing():
                if cuda_graph_idx != -1:
                    KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
                    KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
                    KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
                    KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
                    self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
                    self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                    KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
                    return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
                else:
                    KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
                    KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
                    KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
                    KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
                    self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
                    self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
                    KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
                    return KExpertsCPU.output_gpu_map[self.out_device]
            else:
                input_tensor = input_tensor.contiguous().cpu()
                expert_ids = expert_ids.contiguous().cpu()
                weights = weights.contiguous().to(torch.float32).cpu()
                bsz_tensor = bsz_tensor.contiguous().cpu()
                output = torch.empty_like(input_tensor).contiguous()
                self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
                self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
                self.cpu_infer.sync()
                return output.to(device=object.__getattribute__(self, "out_device"))

@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):

                y += y_
            return y

        @torch.no_grad()
        def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
            outs = self.experts(x, topk_ids, topk_weight)

@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):

                # the `top_x` tensor here.
                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))

            return final_hidden_states
            return final_hidden_states


    class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
        def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
            identity = hidden_states
            orig_shape = hidden_states.shape
            sequence_length = orig_shape[1]
            topk_idx, topk_weight = self.gate(hidden_states)
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

            # only for generate phase
            if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
                # TODO: this branch cause jit bug
                self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
                if self.config.n_shared_experts is not None:
                    y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
                y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
                y += y_
                y.resize_(*orig_shape)
                return y

            if self.config.n_shared_experts is not None:
                y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)

            if isinstance(self.experts, KExpertsBase):
                y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
            elif hidden_states.size(0) > 10:
                # TODO may bugs here
                y = (
                    self.moe_infer(hidden_states, topk_idx, topk_weight)
                    .view(*orig_shape)
                    .to(device=hidden_states.device)
                )
            else:
                # TODO may bugs here
                y = (
                    self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
                    .view(*orig_shape)
                    .to(device=hidden_states.device)
                )
            if self.config.n_shared_experts is not None:
                y += y_
            return y

        @torch.no_grad()
        def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
            outs = torch.empty_like(x)
            outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
            return outs

        @torch.no_grad()
        # TODO may bugs here
        def moe_infer_simple(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
            """
            x: [num_tokens, hidden_size]
            topk_ids, topk_weight: [num_tokens, num_selected_experts]
            """
            outs = torch.zeros_like(x)
            for token_idx in range(topk_ids.size(0)):
                for expert_idx in range(topk_ids.size(1)):
                    expert = self.experts[topk_ids[token_idx, expert_idx]]
                    outs[token_idx] += (
                        expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                    )
            return outs

        @torch.no_grad()
        # TODO may bugs here
        def moe_infer(self, x, topk_ids, topk_weight):
            cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens = x[idxs // topk_ids.shape[1]]
            tokens_per_expert = tokens_per_expert.cpu().numpy()

            outputs = []
            start_idx = 0
            for i, num_tokens in enumerate(tokens_per_expert):
                end_idx = start_idx + num_tokens
                if num_tokens == 0:
                    continue
                expert = self.experts[i + self.ep_rank * self.experts_per_rank]
                tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
                expert_out = expert.forward(tokens_for_this_expert)
                outputs.append(expert_out)
                start_idx = end_idx

            outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

            new_x = torch.empty_like(outs)
            new_x[idxs] = outs
            final_out = (
                new_x.view(*topk_ids.shape, -1)
                .type(topk_weight.dtype)
                .mul_(topk_weight.unsqueeze(dim=-1))
                .sum(dim=1)
                .type(new_x.dtype)
            )
            return final_out


    class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
        def __init__(self,
                     key: str,
                     gguf_loader: GGUFLoader,
                     config: PretrainedConfig,
                     orig_module: nn.Module,
                     # device: str = "cuda",
                     prefill_device: str = "cuda",
                     prefill_op: str | None = "KExpertsTorch",
                     generate_device: str = "cpu",
                     generate_op: str | None = "KExpertsCPU",
                     **kwargs):
            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
            KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
            if generate_op is not None:
                self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
            else:
                self.generate_experts = None
            if prefill_op is not None:
                self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
            else:
                self.prefill_experts = None
            self.gpu_mlp_type = prefill_op
            self.cpu_mlp_type = generate_op
            self.mode = InferenceState.UNLOAD

        def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
            # TODO support w as input
            if not mode:
                mode = InferenceState.GENERATE
            if mode == InferenceState.GENERATE:
                self.prefill_experts.unload()
                self.generate_experts.load(w, warmup=warmup)
                self.device = self.generate_experts.device
                self.mode = mode
            elif mode == InferenceState.PREFILL:
                self.generate_experts.unload()
                self.prefill_experts.load(w, warmup=warmup)
                self.device = self.prefill_experts.device
                self.mode = mode
            elif mode == InferenceState.UNLOAD:
                self.unload()
                self.mode = mode
                self.device = self.generate_experts.device
            else:
                raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

        def unload(self):
            if self.generate_experts is not None:
                self.generate_experts.unload()
            if self.prefill_experts is not None:
                self.prefill_experts.unload()
            self.device = self.generate_experts.device

        def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
            if self.mode == InferenceState.GENERATE:
                assert self.generate_experts is not None, "generate_experts is None"
                return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
            elif self.mode == InferenceState.PREFILL:
                assert self.prefill_experts is not None, "prefill_experts is None"
                return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
            else:
                raise ValueError("load or set_inference_mode before forward")

        def set_inference_mode(self, mode: InferenceState):
            if mode == InferenceState.GENERATE:
                self.load(mode=InferenceState.GENERATE, warmup=False)
            elif mode == InferenceState.PREFILL:
                self.load(mode=InferenceState.PREFILL, warmup=False)
            elif mode == InferenceState.UNLOAD:
                self.unload()
            else:
                raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
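KExpertsCPU hands activations to the CPU MoE kernel through pinned host staging buffers, one set per captured CUDA-graph batch size, using non-blocking copies ordered on the current stream. A minimal sketch of that hand-off pattern follows; run_cpu_moe is a hypothetical stand-in for the CPUInfer submit/sync pair, and the sizes are illustrative.

    # Minimal sketch of the pinned-buffer staging pattern (assumptions:
    # run_cpu_moe is a placeholder for the CPUInfer path, sizes are made up).
    import torch

    hidden_size = 1024
    cuda_graphs = [1, 2, 3, 64]   # one staging buffer per captured batch size

    inputs_cpu = [torch.zeros(bs, hidden_size, pin_memory=True) for bs in cuda_graphs]
    outputs_cpu = [torch.zeros(bs, hidden_size, pin_memory=True) for bs in cuda_graphs]

    def stage_and_run(x_gpu, graph_idx, run_cpu_moe):
        # Asynchronously copy activations into the pinned buffer for this graph,
        inputs_cpu[graph_idx][: x_gpu.size(0)].copy_(x_gpu, non_blocking=True)
        # let the CPU expert kernel fill the pinned output buffer,
        run_cpu_moe(inputs_cpu[graph_idx], outputs_cpu[graph_idx])
        # then copy the result back without blocking the stream.
        return outputs_cpu[graph_idx][: x_gpu.size(0)].to(x_gpu.device, non_blocking=True)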
ktransformers/operators/flashinfer_wrapper.py (view file @ 25cee581)

@@ -86,6 +86,7 @@ class MLAWrapper():

                self.qo_indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
                self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
                self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
            else:
                self.qo_indptr_buf = None

@@ -94,19 +95,22 @@ class MLAWrapper():

                self.kv_len_arr_buf = None

            self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
                self.float_workspace_buffer,
                use_cuda_graph=False,
                use_cuda_graph=use_cuda_graph,
                qo_indptr=self.qo_indptr_buf,
                kv_indptr=self.kv_indptr_buf,
                kv_indices=self.kv_indices_buf,
                kv_len_arr=self.kv_len_arr_buf,
                bsz_tensor=self.batch_size_tensor_buf,
            )

            self.need_plan = True

        def plan(self,
                 qo_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_len_arr,
                 bsz_tensor,
                 num_heads,
                 head_dim_ckv,
                 head_dim_kpe,

@@ -138,6 +142,7 @@ class MLAWrapper():

                sm_scale,
                q_data_type,
                kv_data_type,
                bsz_tensor,
            )

        def run(self, q_nope, q_pe, ckv, k_pe, return_lse=False):

@@ -240,16 +245,17 @@ if __name__ == "__main__":

        #checksame()
        #exit(0)

        max_batch_size = 1
        max_pages = 64
        max_batch_size = 2
        max_batch_tokens = 256
        max_pages = 128
        page_size = 64
        num_heads = 128

        # warm-up
        kv_len = 4023
        q_len = 1
        q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
        q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
        q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
        q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
        kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
        ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)

@@ -260,13 +266,19 @@ if __name__ == "__main__":

            max_pages,
        )

        used_pages = (kv_len + page_size - 1) // page_size
        kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
        qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
        kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
        kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
        kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
        bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")
        wrapper.plan(
            qo_indptr,
            None,
            None,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            bsz_tensor,
            128,
            512,
            64,

@@ -276,14 +288,98 @@ if __name__ == "__main__":

            torch.bfloat16,
        )

        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
        attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
        print(attn_output.shape)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
        graph.replay()

        q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
        k = (
            torch.cat([ckv, k_pe], dim=-1)
            .view(-1, 1, 512 + 64)
            .repeat_interleave(num_heads, dim=1)
        )
        v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

        attn_ref, lse_ref = attention_ref_torch(1, q[:q_len], k[:kv_len], v[:kv_len], True, 192 ** (-0.5))
        torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
        # warm-up finished

        kv_len = 512
        q_len = 128
        pages = max_pages
        used_pages = (kv_len + page_size - 1) // page_size
        q_nope = torch.randn((q_len * 2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
        q_nope[q_len:] = q_nope[:q_len]
        q_pe = torch.randn((q_len * 2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
        q_pe[q_len:] = q_pe[:q_len]
        kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
        kv_cache[used_pages:2 * used_pages] = kv_cache[:used_pages]
        ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)

        kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
        qo_indptr = torch.tensor([0, q_len, q_len * 2], dtype=torch.int32, device="cuda")
        kv_indptr = torch.tensor([0, used_pages, used_pages * 2], dtype=torch.int32, device="cuda")
        kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
        kv_indices[:2 * used_pages] = torch.arange(0, 2 * used_pages, dtype=torch.int32, device="cuda")
        bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")

        wrapper.plan(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            bsz_tensor,
            128,
            512,
            64,
            page_size,
            192 ** (-0.5),
            torch.bfloat16,
            torch.bfloat16,
        )

        q_nope_buf.copy_(q_nope)
        q_pe_buf.copy_(q_pe)
        kv_buf[:pages].copy_(kv_cache)

        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        # ref_torch
        q = torch.cat([q_nope, q_pe], dim=-1)
        k = (
            torch.cat([ckv, k_pe], dim=-1)
            .view(-1, 1, 512 + 64)
            .repeat_interleave(num_heads, dim=1)
        )
        v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

        attn_ref, lse_ref = attention_ref_torch(max_batch_size, q, k[:2 * kv_len], v[:2 * kv_len], True, 192 ** (-0.5))
        torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len * 2], rtol=1e-9, atol=1e-9)

        torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len * 2], rtol=1e-9, atol=1e-9)
        torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
        torch.testing.assert_close(attn_output[q_len:q_len * 2], attn_ref[q_len:q_len * 2], rtol=5e-3, atol=5e-3)

        #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
        #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)
        exit(0)

        for forward_id in range(0, 1):
            print("forward_id", forward_id)
            for layer_id in range(1):

@@ -376,5 +472,4 @@ if __name__ == "__main__":

                #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
                #ktrans_output = torch.load(file_name)
                #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
        print("test past")
        print("test past")
\ No newline at end of file
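The test above drives the wrapper with paged-KV metadata (qo_indptr, kv_indptr, kv_indices, kv_len_arr) built by hand for two identical requests. The following pure-Python sketch shows the bookkeeping behind those arrays; it mirrors the used_pages arithmetic in the test and is not part of the flashinfer API.

    # Sketch of the paged-KV index arrays used by the test (illustrative only).
    def build_page_tables(kv_lens, q_lens, page_size):
        qo_indptr, kv_indptr, kv_indices = [0], [0], []
        next_page = 0
        for kv_len, q_len in zip(kv_lens, q_lens):
            used_pages = (kv_len + page_size - 1) // page_size  # ceil division
            kv_indices.extend(range(next_page, next_page + used_pages))
            next_page += used_pages
            kv_indptr.append(kv_indptr[-1] + used_pages)
            qo_indptr.append(qo_indptr[-1] + q_len)
        return qo_indptr, kv_indptr, kv_indices

    # Two requests of kv_len=512, q_len=128 with page_size=64, as in the test:
    print(build_page_tables([512, 512], [128, 128], 64))
    # -> ([0, 128, 256], [0, 8, 16], [0, 1, ..., 15])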
ktransformers/operators/gate.py (view file @ 25cee581)

@@ -249,4 +249,4 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):

        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
            self.e_score_correction_bias = None
\ No newline at end of file
ktransformers/operators/layernorm.py (new file, 0 → 100644, view file @ 25cee581)

    '''
    Date: 2024-11-13 15:05:52
    LastEditors: Xie Weiyu ervinxie@qq.com
    LastEditTime: 2024-11-25 08:59:19
    '''
    """
    Copyright 2023-2024 SGLang Team
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
    """

    """Fused operators for normalization layers."""

    import logging
    from typing import Optional, Tuple, Union

    from transformers import PretrainedConfig
    import torch
    import torch.nn as nn
    from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
    from ktransformers.operators.base_operator import BaseInjectedModule
    from ktransformers.util.custom_gguf import GGUFLoader
    from flashinfer.norm import (
        fused_add_rmsnorm,
        rmsnorm,
    )

    logger = logging.getLogger(__name__)


    class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
        def __init__(self,
                     key: str,
                     gguf_loader: GGUFLoader,
                     config: PretrainedConfig,
                     orig_module: nn.Module,
                     prefill_device: str = "cuda",
                     generate_device: str = "cuda",
                     **kwargs):
            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
            self.orig_module.__init__(orig_module.hidden_size, orig_module.variance_epsilon)

        def forward(
            self,
            x: torch.Tensor,
            batch_size_tensor: torch.Tensor = None,
            residual: Optional[torch.Tensor] = None,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            #return self.forward_native(x, residual)
            if batch_size_tensor is None:
                return self.forward_native(x)
            if residual is not None:
                fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
                #residual = x + residual
                #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
                return x, residual
            # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
            out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return out

        def forward_native(self, hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states.to(input_dtype)
\ No newline at end of file
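forward_native above is the reference path used when no batch_size_tensor is supplied. As a quick sanity check of the same math outside the injected module, here is a standalone sketch of RMSNorm in float32 (purely illustrative, no flashinfer kernel involved).

    # Standalone RMSNorm reference matching forward_native's math (sketch).
    import torch

    def rmsnorm_ref(x, weight, eps=1e-6):
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
        return weight * (x32 * inv_rms).to(x.dtype)

    x = torch.randn(4, 16, dtype=torch.bfloat16)
    w = torch.ones(16, dtype=torch.bfloat16)
    assert rmsnorm_ref(x, w).shape == x.shape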
ktransformers/operators/linear.py (view file @ 25cee581)

@@ -15,14 +15,16 @@ import ctypes

    import torch
    from torch import Tensor, nn
    import KTransformersOps
    import vLLMMarlin
    from ktransformers.util.custom_gguf import GGUFLoader
    from ktransformers.util.utils import InferenceState
    from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
        MarlinWorkspace,
        marlin_quantize,
        marlin_quantize,
        GPTQ_MARLIN_MIN_THREAD_N,
        GPTQ_MARLIN_MIN_THREAD_K,
        GPTQ_MARLIN_MAX_PARALLEL,
        vllm_marlin_quantize
    )
    from ktransformers.operators.base_operator import BaseInjectedModule
    from transformers.configuration_utils import PretrainedConfig

@@ -84,8 +86,10 @@ class KLinearBase(ABC):

            if self.gguf_loader.safetensor_loader is not None:
                # using safetensor_loader
                tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
                weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
                return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
                if key + '.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
                    weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
                    return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
                return nn.Parameter(tensor)

            elif key + ".weight" in self.gguf_loader.tensor_file_map:
                if key + ".bias" in self.gguf_loader.tensor_file_map:

@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):

            self.weight = None
            self.has_bias = False

        def forward(self, x: torch.Tensor) -> torch.Tensor:
        def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
            dtype = x.dtype
            out_device = x.device
            # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.

@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):

            if self.has_bias:
                self.bias = None


    class KLinearQ8(KLinearBase):
        def __init__(
            self,

@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):

            self.dtype = torch.get_default_dtype()
            self.block_size = block_size

        def forward(self, x: torch.Tensor) -> torch.Tensor:
        def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
            x = x.to(self.device)
            orig_dtype = x.dtype
            x_quantized, scale_x = act_quant(x, self.block_size)

@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):

            self.weight = None
            if self.has_bias:
                self.bias = None


    # TODO: merge two marlin class
    class VLinearMarlin(KLinearBase):
        marlin_q_w: torch.Tensor
        marlin_s: torch.Tensor
        g_idx: torch.Tensor
        sort_indices: torch.Tensor
        has_bias: bool

        def __init__(
            self,
            key: str,
            gguf_loader: GGUFLoader,
            config: PretrainedConfig,
            orig_module: nn.Module = None,
            device: str = "cuda",
            num_bits: int = 4,  # 4-bit/8-bit is supported
            group_size: int = 64,  # -1, 32, 64, 128
            act_order: bool = False,
            is_k_full=True,
            **kwargs,
        ):
            assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
            super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
            self.num_bits = num_bits
            self.group_size = group_size
            self.act_order = act_order
            self.is_k_full = is_k_full
            self.padding = False
            self.orin_in_features = self.in_features
            self.orin_out_features = self.out_features
            if self.in_features % GPTQ_MARLIN_MIN_THREAD_K != 0 or self.out_features % GPTQ_MARLIN_MIN_THREAD_K != 0:
                #print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
                self.padding = True
                self.in_features = (self.in_features + GPTQ_MARLIN_MIN_THREAD_K - 1) // GPTQ_MARLIN_MIN_THREAD_K * GPTQ_MARLIN_MIN_THREAD_K
                self.out_features = (self.out_features + GPTQ_MARLIN_MIN_THREAD_N - 1) // GPTQ_MARLIN_MIN_THREAD_N * GPTQ_MARLIN_MIN_THREAD_N
                #print(f"After padding: in_features={in_features}, out_features={out_features}")

            self.k = self.in_features
            self.n = self.out_features

        def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
            if self.loaded:
                return
            if device is None:
                device = self.device
            assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
            #if self.in_features * self.out_features:
            if w is None:
                w = self.load_weight(device=device)

            if isinstance(w, nn.Parameter):
                # pad weight
                weight = w.view(self.orin_out_features, self.orin_in_features).T
                self.has_bias = False
            elif isinstance(w, tuple):
                w = list(w)
                weight = w[0].view(self.orin_out_features, self.orin_in_features).T
                self.bias = w[1].view(self.orin_out_features)
                self.bias = w[1]
                self.has_bias = True
            else:
                raise ValueError("Invalid weight type")
            weight = weight.to(device)
            if self.has_bias:
                self.bias = self.bias.to(device)

            if self.padding:
                padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
                padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
                weight = padded_weight

            # Pack Marlin linear
            marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
                weight, self.num_bits, self.group_size, self.act_order
            )
            self.workspace = MarlinWorkspace(
                self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device
            )
            self.weight = marlin_q_w
            self.marlin_q_w = marlin_q_w
            self.marlin_s = marlin_s
            self.g_idx = g_idx
            self.sort_indices = sort_indices
            self.k = weight.shape[0]
            self.n = weight.shape[1]
            # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
            self.loaded = True

        def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
            if bsz_tensor is None:
                bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
            # Only support input x as BF16 and FP16
            x = x.to(self.device)
            orig_shape = list(x.shape)
            orig_dtype = x.dtype
            x = x.reshape(-1, orig_shape[-1])
            marlin_s = self.marlin_s.to(x.dtype)
            sms = -1

            x = vLLMMarlin.gptq_marlin_gemm(
                x,
                self.marlin_q_w,
                marlin_s,
                self.g_idx,
                self.sort_indices,
                self.workspace.scratch,
                self.num_bits,
                bsz_tensor,
                # torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
                x.shape[0],
                self.n,
                x.shape[-1],
                sms,
                self.is_k_full,
            )
            # x = KTransformersOps.gptq_marlin_gemm(
            #     x,
            #     self.marlin_q_w,
            #     marlin_s,
            #     self.g_idx,
            #     self.sort_indices,
            #     self.workspace.scratch,
            #     self.num_bits,
            #     x.shape[0],
            #     self.n,
            #     x.shape[-1],
            #     self.is_k_full,
            # )
            if self.has_bias:
                x = x + self.bias
            orig_shape[-1] = self.n
            return x.reshape(orig_shape).to(orig_dtype)

        def unload(self):
            if self.has_bias:
                self.bias = None
            self.marlin_q_w = None
            self.marlin_s = None
            self.g_idx = None
            self.sort_indices = None
            self.workspace = None


    class KLinearMarlin(KLinearBase):
        marlin_q_w: torch.Tensor
        marlin_s: torch.Tensor

@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):

            self.n = weight.shape[1]
            self.loaded = True

        def forward(self, x: torch.Tensor) -> torch.Tensor:
        def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None, **kwargs) -> torch.Tensor:
            # Only support input x as BF16 and FP16
            x = x.to(self.device)
            orig_shape = list(x.shape)

@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):

                if self.w is not None:
                    self.w = None
                if self.has_bias:
                    self.bias = None
                self.bias = None

    LINEAR_MAP = {
        "KLinearMarlin": KLinearMarlin,
        "KLinearTorch": KLinearTorch,
        "KLinearCPUInfer": KLinearCPUInfer,
        "VLinearMarlin": VLinearMarlin,
        "KLinearFP8": KLinearFP8,
        "KLinearQ8": KLinearQ8,
    }

@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):

                self.generate_linear = None
            self.mode = InferenceState.UNLOAD

        def forward(self, x):
        def forward(self, x, bsz_tensor=None):
            if self.mode == InferenceState.PREFILL:
                assert self.prefill_linear is not None, "cpu linear is not initialized"
                y = self.prefill_linear.forward(x)
                y = self.prefill_linear.forward(x, bsz_tensor)
            else:
                assert self.generate_linear is not None, "gpu linear is not initialized"
                y = self.generate_linear.forward(x)
                y = self.generate_linear.forward(x, bsz_tensor)
            return y

        def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):

@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):

                self.unload()
            else:
                raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
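VLinearMarlin pads in_features and out_features up to the Marlin thread-tile multiples before quantizing the weight. The round-up arithmetic it uses is just ceil-to-multiple; a tiny sketch follows, with tile sizes that are assumed for the example rather than read from marlin_utils.

    # Sketch of the round-up-to-tile padding rule (tile sizes are assumptions).
    def round_up(x, multiple):
        return (x + multiple - 1) // multiple * multiple

    GPTQ_MARLIN_MIN_THREAD_K = 128   # assumed for illustration
    GPTQ_MARLIN_MIN_THREAD_N = 64    # assumed for illustration

    print(round_up(576, GPTQ_MARLIN_MIN_THREAD_K))   # 640: padded up to the K tile
    print(round_up(700, GPTQ_MARLIN_MIN_THREAD_N))   # 704: padded up to the N tile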
ktransformers/operators/mlp.py (new file, 0 → 100644, view file @ 25cee581)

    from ktransformers.operators.base_operator import BaseInjectedModule
    from ktransformers.util.custom_gguf import GGUFLoader
    from transformers import PretrainedConfig
    import torch.nn as nn
    from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP


    class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
        def __init__(self,
                     key: str,
                     gguf_loader: GGUFLoader,
                     config: PretrainedConfig,
                     orig_module: nn.Module,
                     prefill_device: str = "cuda",
                     generate_device: str = "cuda",
                     **kwargs):
            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
            self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)

        def forward(self, x, bsz_tensor):
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
            return down_proj
\ No newline at end of file
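The forward above is the usual gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), with the extra bsz_tensor argument threaded through to the injected linears. A self-contained sketch of the same dataflow with plain nn.Linear layers (no bsz_tensor, sizes made up) follows.

    # Gated-MLP dataflow sketch using plain nn.Linear (illustrative sizes).
    import torch
    import torch.nn as nn

    class GatedMLP(nn.Module):
        def __init__(self, hidden_size=64, intermediate_size=128):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

        def forward(self, x):
            # down( act(gate(x)) * up(x) ), as in kDeepseekV3MLP.forward.
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    print(GatedMLP()(torch.randn(2, 64)).shape)  # torch.Size([2, 64])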
ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml (view file @ 25cee581)

@@ -22,7 +22,7 @@

      replace:
        class: ktransformers.operators.linear.KTransformersLinear
        kwargs:
          generate_device: "cpu"
          generate_device: "cuda"
          prefill_device: "cuda"
          generate_op: "KLinearMarlin"
          prefill_op: "KLinearTorch"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml (new file, 0 → 100644, view file @ 25cee581)

    - match:
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
      replace:
        class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
    - match:
        name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
        class: torch.nn.Linear  # only match modules matching name and class simultaneously
      replace:
        class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
          generate_op: "KLinearFP8"
          prefill_op: "KLinearTorch"
    - match:
        name: "^model\\.layers\\..*\\.mlp$"
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
      replace:
        class: ktransformers.operators.experts.KDeepseekV3MoEV2  # mlp module with custom forward function
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
    - match:
        class: ktransformers.models.modeling_deepseek_v3.MoEGate
      replace:
        class: ktransformers.operators.gate.KMoEGate
        kwargs:
          generate_device: "cuda:0"
          prefill_device: "cuda:0"
    - match:
        name: "^model\\.layers\\..*\\.mlp\\.experts$"
      replace:
        class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert paralleism
        kwargs:
          prefill_device: "cuda"
          prefill_op: "KExpertsTorch"
          generate_device: "cpu"
          generate_op: "KExpertsCPU"
          out_device: "cuda"
      recursive: False  # don't recursively inject submodules of this module
    - match:
        name: "^model\\.layers\\..*\\.self_attn$"
      replace:
        class: ktransformers.operators.attention.flashinfer_attn  # optimized MLA implementation
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
    - match:
        name: "^model$"
      replace:
        class: "ktransformers.operators.models.KDeepseekV2Model"
        kwargs:
          per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
    - match:
        name: "^model.embed_tokens"
      replace:
        class: "default"
        kwargs:
          generate_device: "cpu"
          prefill_device: "cpu"
    - match:
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
      replace:
        class: ktransformers.operators.layernorm.RMSNorm
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
    - match:
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
      replace:
        class: ktransformers.operators.mlp.kDeepseekV3MLP
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
    - match:
        name: "^lm_head$"  # regular expression
        class: torch.nn.Linear  # only match modules matching name and class simultaneously
      replace:
        class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
        kwargs:
          generate_device: "cuda"
          prefill_device: "cuda"
          generate_op: "VLinearMarlin"
          prefill_op: "KLinearTorch"
\ No newline at end of file
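The "name" fields in these rules are Python regular expressions matched against dotted module paths, which is how the kv_b_proj linears are excluded from the FP8/Marlin replacement. A small sketch of that matching behavior (the exact matching logic lives in ktransformers' optimize code and may differ in detail):

    # Sketch: how the negative-lookahead name pattern behaves as a Python regex.
    import re

    pattern = r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$"
    names = [
        "model.layers.0.mlp.gate_proj",            # matched -> replaced
        "model.layers.3.self_attn.q_a_proj",       # matched -> replaced
        "model.layers.3.self_attn.kv_b_proj",      # excluded by the lookahead
    ]
    for name in names:
        print(name, bool(re.match(pattern, name)))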
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (view file @ 25cee581)

@@ -10,7 +10,7 @@

        name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
      replace:
        class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
        class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
        kwargs:
          generate_device: "cuda:0"
          prefill_device: "cuda:0"

@@ -18,7 +18,7 @@

        name: "^model\\.layers\\.([3456][0-9])\\."
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
      replace:
        class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
        class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
        kwargs:
          generate_device: "cuda:1"
          prefill_device: "cuda:1"

@@ -66,7 +66,7 @@

        name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
        class: ktransformers.models.modeling_deepseek_v3.MoEGate
      replace:
        class: ktransformers.operators.gate.KMoEGate
        class: ktransformers.operators.gate.KMoEGateDeepSeekV3
        kwargs:
          generate_device: "cuda:0"
          prefill_device: "cuda:0"

@@ -74,7 +74,7 @@

        name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
        class: ktransformers.models.modeling_deepseek_v3.MoEGate
      replace:
        class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
        class: ktransformers.operators.gate.KMoEGateDeepSeekV3  # mlp module with custom forward function
        kwargs:
          generate_device: "cuda:1"
          prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (view file @ 25cee581)

@@ -10,7 +10,7 @@

        name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
        class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
      replace:
        class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
        class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
        kwargs:
          generate_device: "cuda:0"
          prefill_device: "cuda:0"

@@ -66,7 +66,7 @@

        name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
        class: ktransformers.models.modeling_deepseek_v3.MoEGate
      replace:
        class: ktransformers.operators.gate.KMoEGate
        class: ktransformers.operators.gate.KMoEGateDeepSeekV3
        kwargs:
          generate_device: "cuda:0"
          prefill_device: "cuda:0"

@@ -74,7 +74,7 @@

        name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
        class: ktransformers.models.modeling_deepseek_v3.MoEGate
      replace:
        class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
        class: ktransformers.operators.gate.KMoEGateDeepSeekV3  # mlp module with custom forward function
        kwargs:
          generate_device: "cuda:1"
          prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml
0 → 100644
View file @ 25cee581
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoEV2  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
  replace:
    class: ktransformers.operators.layernorm.RMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
  replace:
    class: ktransformers.operators.mlp.kDeepseekV3MLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
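The serve backend introduced in this commit picks one of these rule files automatically from the model architecture when --optimize_config_path is not supplied. A minimal sketch of that selection step; pick_rule_file is a hypothetical helper, and the dictionary mirrors default_optimize_rules in balance_serve.py further down.

# Sketch: choosing an optimize-rules file by architecture (paths relative to the repo).
import os
from transformers import AutoConfig

rules_dir = "ktransformers/optimize/optimize_rules"
default_optimize_rules = {
    "DeepseekV3ForCausalLM": os.path.join(rules_dir, "DeepSeek-V3-Chat-serve.yaml"),
    "Qwen2MoeForCausalLM": os.path.join(rules_dir, "Qwen2-57B-A14B-Instruct-serve.yaml"),
}

def pick_rule_file(model_dir: str, override: str | None = None) -> str:
    if override is not None:  # --optimize_config_path wins when given
        return override
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    return default_optimize_rules[config.architectures[0]]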
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
View file @ 25cee581
...
@@ -38,7 +38,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
+    class: ktransformers.operators.gate.KMoEGate
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
...
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml
0 → 100644
View file @ 25cee581
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoEV2  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
  replace:
    class: ktransformers.operators.layernorm.RMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
  replace:
    class: ktransformers.operators.mlp.kDeepseekV3MLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV4
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
View file @ 25cee581
...
@@ -38,7 +38,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGate
+    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
...
ktransformers/server/args.py
View file @ 25cee581
 import argparse
 from ktransformers.server.backend.args import ConfigArgs, default_args
+from ktransformers.util.utils import get_free_ports


 class ArgumentParser:
     def __init__(self, cfg):
...
@@ -16,20 +16,18 @@ class ArgumentParser:
         parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
         parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
         parser.add_argument("--model_dir", type=str)
-        parser.add_argument("--model_path", type=str)
+        parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
         parser.add_argument("--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter")
         parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
-        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
+        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
-        parser.add_argument("--chunk_prefill_size", type=int, default=8192)
+        parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)
         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
...
@@ -62,7 +60,6 @@ class ArgumentParser:
         parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
         parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
         parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
         parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
         parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
         parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
         parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
...
@@ -103,6 +100,18 @@ class ArgumentParser:
         # local chat
         parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)

+        # async server
+        parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
+        # parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
+        # parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
+        # parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
+        parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
+        parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
+        parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
+        parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)

         args = parser.parse_args()
         if (args.model_dir is not None or args.model_path is not None):
             if (args.model_path is not None):
...
@@ -123,6 +132,15 @@ class ArgumentParser:
         self.cfg.mount_web = args.web
         self.cfg.server_ip = args.host
         self.cfg.server_port = args.port
         self.cfg.backend_type = args.type
         self.cfg.user_force_think = args.force_think

+        args.gpu_memory_size = args.cache_lens * 2 * 576 * 61
+        self.cfg.gpu_memory_size = args.gpu_memory_size
+        free_ports = get_free_ports(3, [args.port])
+        args.sched_port = free_ports[0]
+        args.sched_metrics_port = free_ports[1]
+        args.kvc2_metrics_port = free_ports[2]
+        self.cfg.sched_port = free_ports[0]
+        self.cfg.sched_metrics_port = free_ports[1]
+        self.cfg.kvc2_metrics_port = free_ports[2]

         return args
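The added block sizes the GPU KV-cache budget and reserves ports for the scheduler and metrics services at start-up. A short annotated restatement follows; reading the 576 and 61 factors as the MLA cache width and the DeepSeek-V3 layer count is my assumption, not something stated in the diff.

# Annotated restatement of the added start-up logic (example values, assumptions noted in comments).
from ktransformers.util.utils import get_free_ports

cache_lens = 32768  # example --cache_lens value
# 2 bytes per element (fp16) * 576 cached values per token per layer
# (assumed: 512 compressed-KV dims + 64 RoPE dims for MLA) * 61 layers (assumed for DeepSeek-V3).
gpu_memory_size = cache_lens * 2 * 576 * 61

# Three auxiliary endpoints need free ports; exclude the API port itself (10002 is an example).
sched_port, sched_metrics_port, kvc2_metrics_port = get_free_ports(3, [10002])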
ktransformers/server/backend/args.py
View file @ 25cee581
...
@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
     class Config:
         protected_namespaces = ()

     paged: bool = Field(None, description="Whether to use paged attention kv cache")
     total_context: int = Field(
         None,
         description=(
             "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
             " total to distribute dynamically over however many jobs are active at once"
         ),
     )
     max_batch_size: int = Field(
         None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
     )
-    chunk_prefill_size: int = Field(
+    chunk_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
...
@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
     repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
     frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
     presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
     max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
     response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
     no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
     cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
...
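The rename keeps the CLI flag (--chunk_size above) and the pydantic field in step. A minimal sketch of the renamed field in isolation, mirroring the Field style used in ConfigArgs (the real model carries many more fields):

# Minimal sketch of the renamed field (pydantic, same style as ConfigArgs).
from pydantic import BaseModel, Field

class ChunkConfigSketch(BaseModel):
    chunk_size: int = Field(
        None,
        description="Max chunk size. Determines the size of prefill operations.",
    )

print(ChunkConfigSketch(chunk_size=8192).chunk_size)  # 8192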
ktransformers/server/backend/context_manager.py
View file @ 25cee581
...
@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
 from ktransformers.server.backend.interfaces.transformers import TransformersInterface
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface


 class ThreadContextManager:
     lock: Lock
     threads_context: Dict[ObjectID, ThreadContext]
...
@@ -36,7 +38,16 @@ class ThreadContextManager:
         elif isinstance(self.interface, TransformersInterface):
             new_context = TransformersThreadContext(run, self.interface)
-        else:
-            raise NotImplementedError
+        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
+        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
+        if isinstance(self.interface, BalanceServeInterface):
+            new_context = BalanceServeThreadContext(run, self.interface)
+        else:
+            raise NotImplementedError
+        # elif isinstance(self.interface, BalanceServeInterface):
+        #     new_context = BalanceServeThreadContext(run, self.interface)
+        # else:
+        #     raise NotImplementedError
         self.threads_context[run.thread_id] = new_context
         # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
         re = self.threads_context[run.thread_id]
...
ktransformers/server/backend/interfaces/balance_serve.py
0 → 100644
View file @ 25cee581
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os

ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}


async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        # print(f"Got token: {token}")
        if token is None:
            # str = f'{token}\n\n'
            # str = model.tokenizer.decode(token)
            s = streamer.end()
            if s is not None:
                yield s
            break
        # str = model.tokenizer.decode(token)
        yield streamer.put(token)


def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
    # print(len(query_updates), generated_tokens.size(0), generated_tokens)
    for i in range(generated_tokens.size(0)):
        print(generated_tokens[i].item())
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]


def report_last_time_performance(profiler: Profiler):
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')

        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except:
        logger.info(f'Performance statistics not recorded')


class Engine:
    sched_client: SchedulerClient
    updates: list[sched_ext.QueryUpdate]
    batch: sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache

    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None):
        self.args = args
        # the child process and the parent process cannot share the Config variable
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)
        self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []

        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        self.cache = KDeepSeekV3Cache(config, self.args.page_size)

        self.gen_queue = generated_token_queue

        print(f"Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print(f"Got inference context, sending it to subscribers.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print(f"kv_cache loaded successfully.")

        self.block_num = inference_context.k_cache[0].size(1)
        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
        # print(self.block_num)

        context = zmq.Context()
        self.pub_socket = context.socket(zmq.PUB)
        self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
        # time.sleep(1)  # make sure all subscribers are ready

        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )

        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path

        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)

        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.model.eval()

        # @TODO add config
        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size=args.page_size)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device=self.device, page_size=args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            if hasattr(forward_output, "temperatures"):
                temperatures = forward_output.temperatures[i]
            else:
                temperatures = None
            if hasattr(forward_output, "top_ps"):
                top_ps = forward_output.top_ps[i]
            else:
                top_ps = None
            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_tokens, probs = self.sampler(logit, sample_options)
        return generated_tokens, probs

    def loop(self):
        next_batch = None
        while True:
            self.batch = next_batch
            if self.batch is not None:
                self.model_runner.run(self.batch, self.query_manager)

            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill == True:
                        continue
                    # print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
                    try:
                        self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
                    except queue.Full:
                        pass
                        # print("Queue is full after timeout; unable to put more items.")

            next_batch = self.sched_client.update_last_batch(self.updates)
            if next_batch.query_ids == []:
                next_batch = None
            self.pub_socket.send_pyobj(next_batch)

            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
                # if self.rank == 0:
                generated_tokens, probs = self.sampling(self.model_runner.output)

                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
            else:
                self.updates = []


class BalanceServeThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})

        return local_messages


def run_engine(args, token_queue, broadcast_endpoint, event):
    engine = Engine(args, token_queue, broadcast_endpoint)
    if args.use_cuda_graph:
        engine.model_runner.warmup()

    event.set()
    engine.loop()


class BalanceServeInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map: dict[int, asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")

        self.token_queue = ctx.Queue(maxsize=1000)

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)

        start_event = ctx.Event()

        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
        p.start()
        processes.append(p)
        start_event.wait()

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                    # print(f"Proxy Put token: {token} to queue for query id: {query_id}")
                except asyncio.QueueFull:
                    # print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
                    await self.queue_map[query_id].put(token)
            except queue.Empty:
                # print("no new token")
                # await asyncio.sleep(1)
                await asyncio.sleep(0)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"

        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
        # drop <think> token in chat template
        if input_str.endswith('<think>\n'):
            input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            # local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")

        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            input_ids = torch.cat([input_ids, token_thinks], dim=1)

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        # @TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
        query_add.sample_options.temperature = temperature
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + self.args.max_new_tokens)
        query_id = self.sched_client.add_query(query_add)

        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id

        is_first_token = True
        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="", flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            yield token, None

        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"

        yield RawUsage(
            tokenize_time=profiler.get_timer_sec('tokenize'),
            prefill_time=profiler.get_timer_sec('prefill'),
            decode_time=profiler.get_timer_sec('decode'),
            prefill_count=profiler.get_counter('prefill'),
            decode_count=profiler.get_counter('decode'),
        )
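For reference, a minimal sketch of driving chat_stream on its own: token ids arrive on an asyncio.Queue and a trailing None flushes the streamer and ends the stream. The demo driver below is hypothetical and not part of the commit; only chat_stream and the tokenizer loading come from the code above.

# Hypothetical standalone driver for chat_stream (illustration only).
import asyncio
from transformers import AutoTokenizer
from ktransformers.server.backend.interfaces.balance_serve import chat_stream

async def demo(model_dir: str, token_ids: list[int]) -> None:
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    q: asyncio.Queue = asyncio.Queue()
    for t in token_ids:
        await q.put(t)      # token ids as produced by the engine loop
    await q.put(None)       # None tells chat_stream to flush and stop
    async for text in chat_stream(q, tokenizer):
        if text:
            print(text, end="", flush=True)

# asyncio.run(demo("/models/DeepSeek-V3", [1, 2, 3]))  # placeholder path and token ids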