OpenDAS / ktransformers · Commits · 8db6a4d4

Commit 8db6a4d4 (unverified), authored Feb 27, 2025 by Atream, committed via GitHub on Feb 27, 2025
Merge branch 'main' into main
Parents: cea07d19, 3c8c5805
Changes: 43 files in total; this page shows 20 changed files with 774 additions and 99 deletions (+774, -99).
Changed files on this page:

ktransformers/operators/attention.py (+37, -18)
ktransformers/operators/experts.py (+13, -4)
ktransformers/operators/flashinfer_wrapper.py (+30, -9)
ktransformers/operators/gate.py (+10, -3)
ktransformers/operators/linear.py (+68, -5)
ktransformers/operators/models.py (+9, -4)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml (+63, -0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml (+4, -0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (+157, -0)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml (+2, -2)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml (+1, -0)
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml (+11, -0)
ktransformers/server/api/ollama/completions.py (+94, -39)
ktransformers/server/backend/interfaces/ktransformers.py (+8, -3)
ktransformers/server/backend/interfaces/transformers.py (+2, -2)
ktransformers/tests/mmlu_pro_test.py (+2, -2)
ktransformers/tests/triton_fp8gemm_test.py (+116, -0)
ktransformers/util/custom_gguf.py (+26, -1)
ktransformers/util/custom_loader.py (+86, -0)
ktransformers/util/utils.py (+35, -7)
ktransformers/operators/attention.py

@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache

@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
+                absorb_for_prefill: bool = False,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
        self.mla_wrapper = None
+       self.absorb_for_prefill = absorb_for_prefill

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):

@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            q_nope = q_nope.transpose(1, 2)  # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            q_nope = q_nope.transpose(1, 2)
-           assert q_nope.is_contiguous()
+           # assert q_nope.is_contiguous()
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]

@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)
            attn_output = torch.matmul(attn_output, out_absorb.mT)
            attn_output = attn_output.transpose(1, 2)
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output)

@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
        # decode
-       if q_len == 1:
+       if q_len == 1 or self.absorb_for_prefill:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)

@@ -395,29 +399,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            q_nope = q_nope.transpose(1, 2)  # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
            q_nope = q_nope.transpose(1, 2)
-           assert q_nope.is_contiguous()
+           q_nope = q_nope.contiguous()
+           # assert q_nope.is_contiguous()
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-           q_nope.squeeze_(1)
-           q_pe.squeeze_(1)
+           q_nope.squeeze_(0)
+           q_pe.squeeze_(0)

            # flash attn doesn't support head_dim bigger than 256, use flashinfer
            if self.mla_wrapper is None:
                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph=True)
            if self.mla_wrapper.need_plan:
                self.mla_wrapper.need_plan = False
-               self.mla_wrapper.plan(None, None, None,
-                                     position_ids.squeeze(1)+1,
-                                     self.num_heads,
-                                     self.kv_lora_rank,
-                                     self.qk_rope_head_dim,
-                                     past_key_value.page_size,
-                                     self.softmax_scale,
-                                     q_nope.dtype,
-                                     compressed_kv.dtype)
+               if q_len == 1:
+                   self.mla_wrapper.plan(None, None, None,
+                                         position_ids.squeeze(1)+1,
+                                         self.num_heads,
+                                         self.kv_lora_rank,
+                                         self.qk_rope_head_dim,
+                                         past_key_value.page_size,
+                                         self.softmax_scale,
+                                         q_nope.dtype,
+                                         compressed_kv.dtype)
+               else:
+                   qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                   kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                   self.mla_wrapper.plan(qo_indptr, None, None,
+                                         kv_len_arr,
+                                         self.num_heads,
+                                         self.kv_lora_rank,
+                                         self.qk_rope_head_dim,
+                                         past_key_value.page_size,
+                                         self.softmax_scale,
+                                         q_nope.dtype,
+                                         compressed_kv.dtype)
            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """
            k = (
                torch.cat([compressed_kv, k_pe], dim=-1)

@@ -443,10 +460,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)  # [bsz, self.num_heads, q_len, self.kv_lora_rank]
            attn_output = torch.matmul(attn_output, out_absorb.mT)  # [bsz, self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(1, 2).contiguous()  # [bsz, q_len, self.num_heads, self.kv_lora_rank]
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)  # [bsz, q_len, self.num_heads * self.v_head_dim]
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:

@@ -571,7 +589,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-       if os.name == 'nt':
+       if os.name == 'nt' or get_compute_capability() < 8:
+           print("for Windows or GPU before ampere, use forward_windows")
            return self.forward_windows(
                hidden_states,
                attention_mask,
ktransformers/operators/experts.py

@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
        down_type = None

        for key in keys:
-           if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+           if self.gguf_loader.safetensor_loader is not None:
+               # using a temp ugly way to temprary load the tensor
+               gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
+               up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
+               down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
+               gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
+               up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
+               down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
+
+           elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")

@@ -450,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
            self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
            self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
-       self.up = torch.cat(self.up, dim=0)
-       self.gate = torch.cat(self.gate, dim=0)
-       self.down = torch.cat(self.down, dim=0)
+       self.up = torch.stack(self.up, dim=0)
+       self.gate = torch.stack(self.gate, dim=0)
+       self.down = torch.stack(self.down, dim=0)
        return

    def unload(self):
ktransformers/operators/flashinfer_wrapper.py

@@ -9,7 +9,7 @@ flashinfer_enabled = False
try:
    import flashinfer
-   flashinfer_enabled = False  # disabled now, TODO: use new version of flashinfer and enable
+   flashinfer_enabled = True
    print("found flashinfer")
except ImportError:

@@ -122,7 +122,7 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,

@@ -132,14 +132,14 @@ class MLAWrapper():
            head_dim_ckv,
            head_dim_kpe,
            page_size,
-           False,  # causal is False for decoding
+           True,   # causal
            sm_scale,
            q_data_type,
            kv_data_type,
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse=False):
-       return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
+       return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse=return_lse)

class MLAWrapperSingleton():
    wrappers: dict = {}

@@ -179,6 +179,24 @@ class MLAWrapperSingleton():
            sm_scale,
            q_data_type,
            kv_data_type,)
            wrapper.need_plan = False

+   @classmethod
+   def need_plan_all(cls):
+       for device, wrapper in cls.wrappers.items():
+           wrapper.need_plan = True
+
+   @classmethod
+   def reset_buffer(cls):
+       for device, wrapper in cls.wrappers.items():
+           wrapper.qo_indptr_buf[1] = 1  # assert max_batch_size=1 here.
+
+   @classmethod
+   def update_buffer(cls, max_pages):
+       for device, wrapper in cls.wrappers.items():
+           wrapper.kv_indptr_buf[1] = max_pages  # assert max_batch_size=1 here.
+           wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
+           wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
+
if __name__ == "__main__":

@@ -187,8 +205,9 @@ if __name__ == "__main__":
    page_size = 64
    num_heads = 128

-   q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-   q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+   q_len = 10
+   q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+   q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")

@@ -199,10 +218,10 @@ if __name__ == "__main__":
        max_pages,
    )

-   kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
+   kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
+   qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    wrapper.plan(
-       None,
+       qo_indptr,
        None,
        None,
        kv_len_arr,

@@ -216,6 +235,7 @@ if __name__ == "__main__":
    )

    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+   print(attn_output.shape)

    k = (
        torch.cat([ckv, k_pe], dim=-1)

@@ -235,6 +255,7 @@ if __name__ == "__main__":
        False,
        192 ** (-0.5)
    )
    print(attn_ref.shape)
+   torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
    print("test past")
\ No newline at end of file
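[Editor's sketch] The new `need_plan_all` / `reset_buffer` / `update_buffer` classmethods let callers invalidate and resize every per-device MLA wrapper at once instead of touching them individually. The sketch below shows the call order this commit uses elsewhere (see ktransformers/util/utils.py further down); the surrounding names (`model`, `past_key_values`, `position_ids`, `num_heads`, `head_dim_ckv`, `head_dim_kpe`, `q_head_dim`) are placeholders borrowed from that file, not a new API.

import torch
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

def flashinfer_generation_sketch(model, inputs_embeds, past_key_values, position_ids,
                                 num_heads, head_dim_ckv, head_dim_kpe, q_head_dim,
                                 max_new_tokens):
    # Before prefill: size the kv-page buffers for this request and force
    # every per-device wrapper to re-plan on its next use.
    MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
    MLAWrapperSingleton.need_plan_all()

    logits = model(inputs_embeds=inputs_embeds, past_key_values=past_key_values,
                   return_dict=False, use_cache=True)[0]

    # After the first token: shrink qo_indptr back to single-token decode.
    MLAWrapperSingleton.reset_buffer()

    for i in range(1, max_new_tokens):
        if i > 1:
            # Re-plan for the grown KV length before each decode step.
            MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1,
                                         num_heads, head_dim_ckv, head_dim_kpe,
                                         past_key_values.page_size, q_head_dim ** (-0.5),
                                         torch.bfloat16, torch.bfloat16)
        # ... decode one token with the planned wrappers ...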
ktransformers/operators/gate.py

@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
        for key in keys:
            key = ".".join(key.split(".")[:-1])
-           if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
+           if self.gguf_loader.safetensor_loader is not None:
+               targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+               weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
+               e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
+               weight_type = weight.dtype
+               e_score_correction_bias_type = e_score_correction_bias.dtype
+               res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
+           elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
                tensors = self.load_multi(key, targets, device=device)
                weight = tensors[".ffn_gate_inp.weight"]

@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
-       self.orig_module.weight = self.orig_module.weight.to(device)
-       self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+       self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+       self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))

    def unload(self):
        if self.weight is not None:
ktransformers/operators/linear.py

@@ -26,6 +26,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
+from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))

@@ -78,7 +79,13 @@ class KLinearBase(ABC):
            keys = [self.key]

        for key in keys:
-           if key + ".weight" in self.gguf_loader.tensor_file_map:
+           if self.gguf_loader.safetensor_loader is not None:
+               # using safetensor_loader
+               tensor = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight')
+               weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key + '.weight_scale_inv')
+               return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
+
+           elif key + ".weight" in self.gguf_loader.tensor_file_map:
                if key + ".bias" in self.gguf_loader.tensor_file_map:
                    tensors = self.load_multi(key, ["weight", "bias"], device=device)
                    tensor = tensors["weight"]

@@ -169,7 +176,61 @@ class KLinearTorch(KLinearBase):
        if self.has_bias:
            self.bias = None

+class KLinearFP8(KLinearBase):
+   # this kernel requires special handling for weight
+   # Please load the weight file downloaded from KVCache.AI
+   marlin_q_w: torch.Tensor
+   marlin_s: torch.Tensor
+   g_idx: torch.Tensor
+   sort_indices: torch.Tensor
+   has_bias: bool
+   weight: torch.Tensor
+   scale_w: torch.Tensor
+   bias: torch.Tensor
+   def __init__(
+       self,
+       key: str,
+       gguf_loader: GGUFLoader,
+       config: PretrainedConfig,
+       orig_module: nn.Module = None,
+       device: str = "cuda",
+       block_size: int = 128,
+       **kwargs,
+   ):
+       super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+       self.has_bias = False
+       self.dtype = torch.get_default_dtype()
+       self.block_size = block_size
+
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+       x = x.to(self.device)
+       orig_dtype = x.dtype
+       x_quantized, scale_x = act_quant(x, self.block_size)
+       y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
+       return y.to(dtype=orig_dtype)
+
+   def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
+       if device is None: device = self.device
+       if w is None: w = self.load_weight(device=device)
+       ### TODO fit weight_inv format
+       if isinstance(w, tuple):
+           self.weight = w[0].to(device)
+           self.weight_scale_inv = w[1].to(device)
+           self.has_bias = False
+       else:
+           raise ValueError("Invalid weight type")
+       self.weight = self.weight.to(device)
+       if self.has_bias:
+           self.bias = self.bias.to(device)
+
+   def unload(self):
+       if self.weight is not None:
+           self.weight = None
+       if self.has_bias:
+           self.bias = None
+
class KLinearMarlin(KLinearBase):
    marlin_q_w: torch.Tensor
    marlin_s: torch.Tensor

@@ -404,7 +465,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
    "KLinearMarlin": KLinearMarlin,
    "KLinearTorch": KLinearTorch,
-   "KLinearCPUInfer": KLinearCPUInfer
+   "KLinearCPUInfer": KLinearCPUInfer,
+   "KLinearFP8": KLinearFP8,
}

class KTransformersLinear(BaseInjectedModule, KLinearBase):

@@ -440,10 +502,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
    def forward(self, x):
        if self.mode == InferenceState.PREFILL:
            assert self.prefill_linear is not None, "cpu linear is not initialized"
-           return self.prefill_linear.forward(x)
+           y = self.prefill_linear.forward(x)
        else:
            assert self.generate_linear is not None, "gpu linear is not initialized"
-           return self.generate_linear.forward(x)
+           y = self.generate_linear.forward(x)
+       return y

    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
        if not mode:
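[Editor's sketch] The new KLinearFP8 path keeps the FP8 weight and its block-wise `weight_scale_inv` on the GPU and quantizes activations on the fly. A minimal sketch of that forward sequence using the same `act_quant` / `fp8_gemm` helpers imported above; the weight, scale shapes, and FP8 dtype below are illustrative stand-ins, not real checkpoint data.

import torch
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant

in_features, out_features, block_size = 2048, 1024, 128
x = torch.randn(4, in_features, dtype=torch.bfloat16, device="cuda")   # activations
# Stand-in FP8 weight and per-block reciprocal scales; a real model loads these
# from the FP8 safetensors checkpoint (see the safetensor loader added in this commit).
weight = torch.randn(out_features, in_features, device="cuda").to(torch.float8_e4m3fn)
weight_scale_inv = torch.ones(out_features // block_size, in_features // block_size,
                              dtype=torch.float32, device="cuda")

x_q, x_scale = act_quant(x, block_size)                # block-quantize activations
y = fp8_gemm(x_q, x_scale, weight, weight_scale_inv)   # FP8 GEMM with dequant scales
y = y.to(torch.bfloat16)                               # cast back to the caller's dtype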
ktransformers/operators/models.py

@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (

@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
        if per_layer_prefill_flag:
            causal_mask = None
        else:
-           causal_mask = self._update_causal_mask(
-               attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-           )
+           if os.name == 'nt' or get_compute_capability() < 8:
+               print("for Windows or GPU before ampere, use forward_windows")
+               # only use mask in forward windows or can't flash attn
+               causal_mask = self._update_causal_mask(
+                   attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+               )
+           else:
+               causal_mask = None

        # embed positions
        hidden_states = inputs_embeds
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml (new file, mode 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
\ No newline at end of file
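[Editor's sketch] Like the other rule files in this directory, this YAML is consumed when the model is built: each rule matches modules by name regex and/or class and swaps them for the listed operator. A hedged sketch of wiring it up, mirroring the `optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)` call that appears in ktransformers/server/backend/interfaces/ktransformers.py below; the `ktransformers.optimize.optimize` import path and the placeholder directories are assumptions for illustration.

import torch
from transformers import AutoConfig
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed import path

rule_path = "ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml"
model_dir = "/path/to/DeepSeek-V3"            # placeholder: HF config + tokenizer
weights_dir = "/path/to/DeepSeek-V3-weights"  # placeholder: FP8 safetensors (or GGUF) directory

config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
with torch.device("meta"):
    model = custom_models[config.architectures[0]](config)
# Replace matched modules (KLinearFP8 for linears, KExpertsCPU for experts, ...)
# and load the weights according to the rules above.
optimize_and_load_gguf(model, rule_path, weights_dir, config)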
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml

@@ -293,6 +293,7 @@
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
+     absorb_for_prefill: False

# GPU 1: layers 15–29
- match:

@@ -302,6 +303,7 @@
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
+     absorb_for_prefill: False

# GPU 2: layers 30–44
- match:

@@ -311,6 +313,7 @@
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
+     absorb_for_prefill: False

# GPU 3: layers 45–60
- match:

@@ -320,6 +323,7 @@
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
+     absorb_for_prefill: False

# === Overall Model Replacement with Transfer Map ===
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml (new file, mode 100644)

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\."
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearFP8"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:0"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"
  recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
    name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention  # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
      transfer_map:
        30: "cuda:1"

- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"

- match:
    name: "^lm_head"
    class: torch.nn.Linear
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"

- match:
    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml

@@ -168,5 +168,5 @@
  replace:
    class: "default"
    kwargs:
-     generate_device: "cuda:0"
-     prefill_device: "cuda:0"
+     generate_device: "cuda:1"
+     prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml

@@ -60,6 +60,7 @@
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
+     absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml

@@ -53,6 +53,17 @@
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False # don't recursively inject submodules of this module
+# if want to use more VRAM, use experts Marlin and disable CUDA Graph(disable CUDA Graph may cause low performance)
+#- match:
+#    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+#  replace:
+#    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
+#  kwargs:
+#    prefill_device: "cuda"
+#    prefill_op: "KExpertsTorch"
+#    generate_device: "cuda"
+#    generate_op: "KExpertsMarlin"
+#  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
ktransformers/server/api/ollama/completions.py

@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase

router = APIRouter(prefix='/api')

# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class OllamaGenerateCompletionRequest(BaseModel):

@@ -40,61 +40,121 @@ class OllamaGenerateCompletionRequest(BaseModel):
    keep_alive: Optional[str] = Field("5m", description="Controls how long the model will stay loaded into memory following the request.")

class OllamaGenerationStreamResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool = Field(...)

class OllamaGenerationResponse(BaseModel):
    pass

@router.post("/generate", tags=['ollama'])
async def generate(request: Request, input: OllamaGenerateCompletionRequest):
    id = str(uuid4())

    interface: BackendInterfaceBase = get_interface()
    print(f'COMPLETION INPUT:----\n{input.prompt}\n----')

    config = Config()

    if input.stream:
        async def inner():
            async for token in interface.inference(input.prompt, id):
                d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response=token, done=False)
                yield d.model_dump_json() + '\n'
-               # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-               # yield f"{json.dumps(d)}\n"
-           # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-           # yield f"{json.dumps(d)}\n"
            d = OllamaGenerationStreamResponse(model=config.model_name, created_at=str(datetime.now()), response='', done=True)
            yield d.model_dump_json() + '\n'
        return check_link_response(request, inner())
    else:
        raise NotImplementedError

# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
class OllamaChatCompletionMessage(BaseModel):
    role: str
    content: str

class OllamaChatCompletionRequest(BaseModel):
-   pass
+   model: str = Field(..., description="The model name, which is required.")
+   messages: List[OllamaChatCompletionMessage] = Field(..., description="A list of messages to generate a response for.")
+   stream: bool = Field(True, description="If true, the response will be streamed.")

class OllamaChatCompletionStreamResponse(BaseModel):
-   pass
+   model: str
+   created_at: str
+   message: dict
+   done: bool = Field(...)
+   total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+   load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+   prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+   prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+   eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+   eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")

class OllamaChatCompletionResponse(BaseModel):
    pass

@router.post("/chat", tags=['ollama'])
async def chat(request: Request, input: OllamaChatCompletionRequest):
-   raise NotImplementedError
+   id = str(uuid4())
+   interface: BackendInterfaceBase = get_interface()
+   config = Config()
+
+   # Convert the messages into a single prompt string
+   prompt = ""
+   for msg in input.messages:
+       prompt += f"{msg.role}: {msg.content}\n"
+   prompt += "assistant:"
+
+   if input.stream:
+       async def inner():
+           start_time = time()  # record the start time (seconds)
+           eval_count = 0       # number of generated tokens
+           tokens = []
+
+           async for token in interface.inference(prompt, id):
+               d = OllamaChatCompletionStreamResponse(
+                   model=config.model_name,
+                   created_at=str(datetime.now()),
+                   message={"role": "assistant", "content": token},
+                   done=False
+               )
+               yield d.model_dump_json() + '\n'
+           # compute performance statistics
+           end_time = time()
+           total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
+           prompt_eval_count = len(prompt.split())  # rough estimate of the prompt token count
+           eval_duration = total_duration           # assume all time was spent generating (simplification)
+           prompt_eval_duration = 0                 # assume no separate prompt-evaluation time
+           load_duration = 0                        # assume load time is unknown
+
+           d = OllamaChatCompletionStreamResponse(
+               model=config.model_name,
+               created_at=str(datetime.now()),
+               message={},
+               done=True,
+               total_duration=total_duration,
+               load_duration=load_duration,
+               prompt_eval_count=prompt_eval_count,
+               prompt_eval_duration=prompt_eval_duration,
+               eval_count=eval_count,
+               eval_duration=eval_duration
+           )
+           yield d.model_dump_json() + '\n'
+       return check_link_response(request, inner())
+   else:
+       raise NotImplementedError("Non-streaming chat is not implemented.")

# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
class OllamaModel(BaseModel):

@@ -103,9 +163,8 @@ class OllamaModel(BaseModel):
    size: int

# TODO: fill the rest correctly
-# mock ollama
@router.get("/tags", tags=['ollama'])
async def tags():
    config = Config()
    # TODO: fill this correctly, although it does not effect Tabby

@@ -138,25 +197,21 @@ class OllamaShowResponse(BaseModel):
    class Config:
        protected_namespaces = ()

@router.post("/show", tags=['ollama'])
async def show(request: Request, input: OllamaShowRequest):
    config = Config()
    # TODO: Add more info in config to return, although it does not effect Tabby
    return OllamaShowResponse(
        modelfile = "# Modelfile generated by ...",
        parameters = " ",
        template = " ",
        details = OllamaShowDetial(
            parent_model = " ",
            format = "gguf",
            family = " ",
            families = [" "],
            parameter_size = " ",
            quantization_level = " "
        ),
        model_info = OllamaModelInfo()
    )
\ No newline at end of file
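[Editor's sketch] Both endpoints stream newline-delimited JSON: one OllamaGenerationStreamResponse / OllamaChatCompletionStreamResponse object per line, with a final `done: true` record carrying the nanosecond timing counters. A hedged client-side sketch of consuming the /api/chat stream; the host and port are placeholders for wherever the ktransformers server is listening, and `requests` is just one possible HTTP client.

import json
import requests

resp = requests.post(
    "http://localhost:10002/api/chat",   # placeholder address of the local server
    json={"model": "DeepSeek-V3", "messages": [{"role": "user", "content": "hello"}], "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["done"]:
        # final record: total_duration, eval_count, etc. (nanosecond counters)
        print("\n-- total_duration:", chunk.get("total_duration"), "ns")
    else:
        print(chunk["message"]["content"], end="", flush=True)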
ktransformers/server/backend/interfaces/ktransformers.py

@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

warm_uped = False

@@ -35,9 +36,9 @@ class KTransformersInterface(TransformersInterface):
        with torch.device("meta"):
            self.model = custom_models[config.architectures[0]](config)
        if default_args.optimize_config_path is None:
-           optimize_rule_path = default_optimize_rules[config.architectures[0]]
+           optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
-           optimize_rule_path = args.optimize_config_path
+           optimize_config_path = args.optimize_config_path

        # print(optimize_config)

@@ -47,7 +48,7 @@ class KTransformersInterface(TransformersInterface):
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
-       optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
+       optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)

        self.device_map = self.model.gguf_loader.tensor_device_map
        # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")

@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
            input_ids = input_ids.to("cpu")
            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
        torch.cuda.set_device(device)
+       if flashinfer_enabled:
+           MLAWrapperSingleton.need_plan_all()
        if self.use_static_cache:
            logits = self.model(
                inputs_embeds=inputs_embeds,

@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
        else:
            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+       if flashinfer_enabled:
+           MLAWrapperSingleton.reset_buffer()
        self.prepare_logits_wrapper(input_ids, device)
        next_token = self.logits_to_token(logits[0, -1, :])
        yield self.append_new_tokens(next_token)
ktransformers/server/backend/interfaces/transformers.py

@@ -333,14 +333,14 @@ class TransformersInterface(BackendInterfaceBase):
        for i in range(1, self.args.max_new_tokens):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-               if i > 1 and flashinfer_enabled:
+               if flashinfer_enabled:
                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32)+1,
                                                 num_heads=self.model.config.num_attention_heads,
                                                 head_dim_ckv=self.model.config.kv_lora_rank,
                                                 head_dim_kpe=self.model.config.qk_rope_head_dim,
                                                 page_size=self.cache.page_size,
                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5),
                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                next_token = self.decode_one_tokens()
                self.profiler.inc("decode")
-               if next_token == self.tokenizer.eos_token_id:
+               if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
                    assert self.args.batch_size == 1
                    break
                yield self.append_new_tokens(next_token)
ktransformers/tests/mmlu_pro_test.py

@@ -173,8 +173,8 @@
    parser = argparse.ArgumentParser(description="API Generate Tester")
    parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
    parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
-   parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
-   parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
+   parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file")
+   parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file")
    parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
    parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
    # parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
ktransformers/tests/triton_fp8gemm_test.py (new file, mode 100644)

import torch
import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal
import time
# use dir path
import os
import sys
sys.path.insert(0, "/home/azure/ktransformers")
print(sys.path)
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors import safe_open

world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"

# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined

def test_fp8_gemm_vs_torch_matmul():
    # Test case 1: Create random matrices of size (M, K) and (K, N)
    M, K, N = 64, 128, 256  # Matrix dimensions
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')

    # Apply act_quant to both matrices
    x_quantized, scale_x = act_quant(x, block_size)
    weight_quantized, scale_w = act_quant(weight, block_size)

    # mk continous
    x_quantized = x_quantized.contiguous()
    weight_quantized = weight_quantized.contiguous()
    scale_x = scale_x.contiguous()
    scale_w = scale_w.contiguous()

    # Perform fp8_gemm using the quantized tensors
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)

    # Perform torch.matmul using the original floating point tensors
    result_torch_matmul = torch.matmul(x, weight.T)
    print(f'result_torch_matmul: {result_torch_matmul.shape}')
    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')

    print(f"result_fp8_gemm:\n{result_fp8_gemm}")
    print(f"result_torch_matmul:\n{result_torch_matmul}")

def test_fp8_gemm_vs_torch_matmul_load():
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    # weight_dequant
    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 64
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    x_quantized, scale_x = act_quant(x, block_size)

    # Test case 1: quantized x matmal with undequantized weight
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    print(f"result_fp8_gemm:\n{result_fp8_gemm}")
    print(f"dtype {result_fp8_gemm.dtype}")

    # Perform torch.matmul using the original floating point tensors
    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
    print(f"result_torch_matmul:\n{result_torch_matmul}")

def test_fp8_gemm_tplops():
    file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
    with safe_open(file_path, framework="pt", device=0) as f:
        weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
        scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")

    # weight_dequant
    weight_dequantized = weight_dequant(weight, scale)
    print(f"weight_dequantized: {weight_dequantized.shape}")
    N, K = weight_dequantized.shape
    M = 6400
    x = torch.randn(2, M, K, dtype=torch.bfloat16, device='cuda')
    # x_quantized, scale_x = act_quant(x, block_size)

    # Calculate time for 1000 fp8_gemm
    i = 10
    flops_per_gemm = 2 * M * N * K
    total_flops = i * flops_per_gemm

    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    x_quantized, scale_x = act_quant(x, block_size)
    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)

    t0 = time.time()
    torch.cuda.synchronize()
    for i in range(i):
        x_quantized, scale_x = act_quant(x, block_size)
        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    torch.cuda.synchronize()
    t1 = time.time()

    total_time = t1 - t0
    tflops = total_flops / total_time / 1e12
    print(f"total_time: {total_time}")
    print(f"tflops: {tflops}")

if __name__ == "__main__":
    test_fp8_gemm_vs_torch_matmul()
    test_fp8_gemm_vs_torch_matmul_load()
    test_fp8_gemm_tplops()
\ No newline at end of file
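[Editor's note] The throughput reported by test_fp8_gemm_tplops uses the usual dense-GEMM count of 2*M*N*K floating-point operations per call. A small sanity check of that arithmetic; the shape and timing below are illustrative placeholders, not measurements.

# Rough sanity check of the TFLOPS arithmetic used in test_fp8_gemm_tplops.
M, N, K = 6400, 4096, 8192        # illustrative GEMM shape, not the real down_proj size
iters = 10
flops_per_gemm = 2 * M * N * K    # one multiply + one add per (m, n, k) triple
total_flops = iters * flops_per_gemm

measured_seconds = 0.05           # substitute the measured (t1 - t0) here
print(f"{total_flops / measured_seconds / 1e12:.1f} TFLOPS")
# Note: x in the test carries a leading batch dimension of 2 that this count ignores.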
ktransformers/util/custom_gguf.py

@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
+from .custom_loader import SafeTensorLoader
import ctypes

class GGMLQuantizationType(IntEnum):

@@ -128,6 +129,7 @@ GGML_BLOCK_SIZES = {
    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
    "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
+   "FP8": 1,
}

GGML_ELEMENTS_PER_BLOCK = {

@@ -143,6 +145,7 @@ GGML_ELEMENTS_PER_BLOCK = {
    "Q5_K": 256,
    "Q6_K": 256,
    "IQ4_XS": 256,
+   "FP8": 1,
}

DATA_TYPES = {

@@ -159,6 +162,7 @@ DATA_TYPES = {
    "uint64": 10,
    "int64": 11,
    "float64": 12,
+   "FP8": 13,
}

class GGUFLoader:

@@ -166,12 +170,15 @@ class GGUFLoader:
    gguf_path: str
    tensor_file_map: dict # {tensor_name: tensor_file_path}
    gguf_file_meta: dict
+   safetensor_loader: SafeTensorLoader

    def __init__(self, gguf_path: str):
        # Check dir exist
        if not os.path.exists(gguf_path):
            raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
        if os.path.isfile(gguf_path):
            gguf_path = os.path.dirname(gguf_path)

+       self.safetensor_loader = None
+
        self.tensor_info = {}
        self.gguf_path = gguf_path

@@ -179,7 +186,13 @@ class GGUFLoader:
        self.file_data_map = {}
        self.gguf_file_meta = {}
        self.tensor_device_map = {}

+       # I know this is ugly, but I don't want to change the original code too much
+       # TODO: merge gguf load and other loads.
+       safetensor_loader = SafeTensorLoader(gguf_path)
+       if safetensor_loader.tensor_file_map:
+           self.safetensor_loader = safetensor_loader
+           return
        # Walk through all the .gguf files in the directory
        found_gguf = False
        for root, dirs, files in os.walk(gguf_path):

@@ -286,6 +299,13 @@ class GGUFLoader:
        itemsize = int(np.empty([], dtype=item_type).itemsize)
        return mmap_data[offset : offset + itemsize * item_count]

+   def get_undequanted_tensor_and_ggml_type(self, name):
+       t = self.tensor_info[name]
+       data = self.get_mmap_tensor(name)
+       ggml_type = t["ggml_type"]
+       data = torch.from_numpy(data)
+       return data, ggml_type
+
    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device="cuda", target_dtype=torch.get_default_dtype()) -> torch.Tensor:
        t = self.tensor_info[name]
        if device.lower() == "cpu":

@@ -310,6 +330,8 @@ class GGUFLoader:
        values = GGML_DEQUANTIZE[ggml_name](data)
        values = torch.from_numpy(values.copy())

+       if ggml_name == "BF16":
+           values = values.view(torch.bfloat16)
        values = values.view(shape[-2::-1])

        return values

@@ -418,6 +440,9 @@ def read_value(f, data_type):
        elem_type, count = struct.unpack("<IQ", f.read(4 + 8))
        return [read_value(f, elem_type) for _ in range(count)]

+   elif data_type == DATA_TYPES["FP8"]:
+       return struct.unpack("<B", f.read(1))[0]
+
    else:
        raise NotImplementedError(f"Data type {data_type} not implemented")
ktransformers/util/custom_loader.py (new file, mode 100644)

import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file

class SafeTensorLoader:
    tensor_file_map = {}
    tensor_type_map = {}
    file_handle_map = {}

    def __init__(self, file_path: str):
        self.__load_tensor_file_map(file_path)

    def __load_tensor_file_map(self, file_path: str):
        # Normalize the given path and make sure we end up with a folder path
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Path not found: {file_path}")
        if os.path.isfile(file_path):
            folder_path = os.path.dirname(file_path)
        else:
            folder_path = file_path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue

                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        # if not found_safetensor:
        #     raise FileNotFoundError(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, key: str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key)
        return tensor.to(device)

    def close_all_handles(self):
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    def load_dequantized_tensor(self, key: str, device: str="cpu"):
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        file = self.tensor_file_map[key]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(key).to(device)
        if key.endswith(".weight"):
            if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)
\ No newline at end of file
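[Editor's sketch] Usage in brief: pointed at a directory of safetensors shards, the loader maps every tensor name to its shard, and load_dequantized_tensor applies weight_dequant whenever a sibling "<name>.weight_scale_inv" tensor exists. The checkpoint path below is a placeholder; the tensor names come from the FP8 test file added in this commit.

from ktransformers.util.custom_loader import SafeTensorLoader

loader = SafeTensorLoader("/path/to/DeepSeek-V3")   # placeholder checkpoint directory

# Raw FP8 tensor plus its block scales, exactly as stored in the shard:
w_fp8 = loader.load_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
scale = loader.load_tensor("model.layers.0.mlp.down_proj.weight_scale_inv", device="cuda")

# Or let the loader dequantize in one step (weight_dequant is applied only when
# the matching .weight_scale_inv entry exists):
w = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")

loader.close_all_handles()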
ktransformers/util/utils.py

@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
warm_uped = False

+def get_compute_capability(device: torch.device = None):
+   if torch.cuda.is_available():
+       if device is None:
+           num_gpus = torch.cuda.device_count()
+           min_compute_capability_major = 100
+           for gpu_id in range(num_gpus):
+               gpu_props = torch.cuda.get_device_properties(gpu_id)
+               min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
+           return min_compute_capability_major
+       else:
+           return torch.cuda.get_device_properties(device)
+
def set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]

@@ -66,13 +78,22 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)
-       if translated_key in gguf_loader.tensor_file_map:
+
+       # TODO: Merge all loader.
+       # I know this is ugly but lets do it for now.
+       if gguf_loader.safetensor_loader is not None:
+           load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
+           tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+       else:
+           load_dequantized_tensor = gguf_loader.load_gguf_tensor
+           tensor_file_map = gguf_loader.tensor_file_map
+
+       if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            print(f"loading {translated_key} to {device}")
            torch.cuda.empty_cache()
            # device = "cpu" if "embd" in translated_key else "cuda"
-           weights = gguf_loader.load_gguf_tensor(translated_key, device=device).to(dtype=target_dtype)
+           weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:

@@ -154,6 +175,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
        else:
            inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+       if use_flashinfer_mla:
+           MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
+           MLAWrapperSingleton.need_plan_all()
+
        logits = model(
            inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
        )[0][:, -1, :].unsqueeze(0).clone().to(torch_device)

@@ -176,6 +201,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        else:
            next_token = torch.argmax(next_token_scores, dim=-1)
        first_token_time = time.time() - start_time

+       if use_flashinfer_mla:
+           MLAWrapperSingleton.reset_buffer()
+
        prefill_count = seq_length
        prefill_time = first_token_time

@@ -193,15 +221,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        start_time = time.time()
        for i in range(1, max_new_tokens):
-           if use_flashinfer_mla:
-               MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1, num_heads, head_dim_ckv,
-                                            head_dim_kpe, past_key_values.page_size, q_head_dim ** (-0.5),
-                                            torch.bfloat16, torch.bfloat16)
            global warm_uped
            if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
                warm_uped = True
                cuda_graph_runner = CUDAGraphRunner()
                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+           if i > 1 and use_flashinfer_mla:
+               MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1, num_heads, head_dim_ckv,
+                                            head_dim_kpe, past_key_values.page_size, q_head_dim ** (-0.5),
+                                            torch.bfloat16, torch.bfloat16)
            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
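[Editor's sketch] Called with no argument, get_compute_capability returns the minimum compute-capability major version across all visible GPUs; that is what the `os.name == 'nt' or get_compute_capability() < 8` checks added in attention.py and models.py rely on to fall back to the non-FlashInfer path on Windows or pre-Ampere cards. A small sketch of that dispatch:

import os
from ktransformers.util.utils import get_compute_capability

def pick_attention_path() -> str:
    # Mirrors the gating added in attention.py / models.py: Windows and GPUs
    # below compute capability 8.x (pre-Ampere) take the fallback path.
    if os.name == 'nt' or get_compute_capability() < 8:
        return "forward_windows (masked, no flashinfer)"
    return "flashinfer MLA path"

print(pick_attention_path())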