OpenDAS / ktransformers / Commits

Unverified commit 0bc43e02, authored Mar 07, 2025 by Atream, committed by GitHub Mar 07, 2025

Merge pull request #839 from kvcache-ai/fix-precision-flashinfer

fix flashinfer precision

Parents: 96d75d53 d453c320

Showing 5 changed files with 148 additions and 58 deletions (+148 -58)
ktransformers/operators/attention.py  +1 -1
ktransformers/operators/flashinfer_wrapper.py  +143 -54
ktransformers/server/backend/interfaces/transformers.py  +1 -1
ktransformers/tests/humaneval/eval_api.py  +2 -1
ktransformers/util/utils.py  +1 -1
ktransformers/operators/attention.py

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_groupe
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
-    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 logger = logging.getLogger("attention")
ktransformers/operators/flashinfer_wrapper.py

 '''
 Description : flashinfer MLA wrapper
 Author      : Boxin Zhang
-Version     : 0.2.2
+Version     : 0.2.3
 '''
 import torch
 import os
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped

 flashinfer_enabled = False
@@ -17,7 +19,7 @@ except ImportError:
 import math

-def attention_ref(
+def attention_ref_torch(
     batch_size,
     q: torch.Tensor,
     k: torch.Tensor,
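As an aside on the renamed helper: a pure-PyTorch reference such as attention_ref_torch is typically a plain scaled-dot-product attention run in float32 that also returns the log-sum-exp, so the FlashInfer and Triton kernels can be checked against it with torch.testing.assert_close. The sketch below is only an assumption about its shape (argument names follow the signature visible in this hunk; the lse convention and tensor layout are guesses), not the repository's implementation.

# Hedged sketch of a torch-only reference attention; the real attention_ref_torch
# in flashinfer_wrapper.py may differ (e.g. in its lse base or masking details).
import torch

def attention_ref_sketch(batch_size, q, k, v, causal, sm_scale):
    # Assumed layouts: q is [batch*q_len, heads, dim], k/v are [batch*kv_len, heads, dim].
    q_len, kv_len = q.shape[0] // batch_size, k.shape[0] // batch_size
    qf = q.view(batch_size, q_len, *q.shape[1:]).float()
    kf = k.view(batch_size, kv_len, *k.shape[1:]).float()
    vf = v.view(batch_size, kv_len, *v.shape[1:]).float()
    scores = torch.einsum("bqhd,bkhd->bhqk", qf, kf) * sm_scale
    if causal:
        # Each query may attend to keys up to its aligned position at the end of the cache.
        mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device).tril(kv_len - q_len)
        scores = scores.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # natural log; the real helper may use another base
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), vf)
    return out.reshape(q.shape[0], q.shape[1], v.shape[-1]).to(q.dtype), lse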
@@ -122,7 +124,7 @@ class MLAWrapper():
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
@@ -139,11 +141,6 @@ class MLAWrapper():
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

class MLAWrapperSingleton():
@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

def checksame():
-    flashinfer_folder = "./flashinfer_output"
+    flashinfer_folder = "./kv_cache_flashinfer"
-    triton_folder = "./triton_output"
+    triton_folder = "./kv_cache_triton"

    max_layer_id = 1
    max_forward_id = 2

    for forward_id in range(0, 19):
        print("forward_id", forward_id)
        for layer_id in range(max_layer_id):
            print(layer_id)
            #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
            file_name = f"layer_{layer_id}.pt"
            flashinfer_path = os.path.join(flashinfer_folder, file_name)
            triton_path = os.path.join(triton_folder, file_name)

            if not os.path.exists(triton_path):
                print(f"{file_name} not exist in {triton_folder}")
                continue
            if not os.path.exists(flashinfer_path):
                print(f"{file_name} not exist in {flashinfer_folder}")
                continue

            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
            try:
                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
            except AssertionError as e:
                print(e)
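checksame() assumes that per-layer tensors were previously dumped to ./kv_cache_flashinfer and ./kv_cache_triton; the code that writes those files is not part of this diff. A hypothetical dump helper compatible with the file naming checksame() loads might look like the following (the helper itself and its parameters are assumptions for illustration only):

import os
import torch

def dump_layer_tensor(tensor: torch.Tensor, layer_id: int, backend: str) -> None:
    # Hypothetical helper, not in the diff: writes one layer's tensor where
    # checksame() expects to find f"layer_{layer_id}.pt" for the given backend.
    folder = {"flashinfer": "./kv_cache_flashinfer", "triton": "./kv_cache_triton"}[backend]
    os.makedirs(folder, exist_ok=True)
    torch.save(tensor.detach().cpu(), os.path.join(folder, f"layer_{layer_id}.pt"))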
if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)

    #checksame()
    #exit(0)

    max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
    page_size = 64
    num_heads = 128

    # warm-up
    kv_len = 4023
    q_len = 1
    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
-    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
-    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)

    wrapper = MLAWrapperSingleton.get_instance(
@@ -241,51 +276,105 @@ if __name__ == "__main__":
        torch.bfloat16,
    )

-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    print(attn_output.shape)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    kv_len = 6789
    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    wrapper.plan(
        qo_indptr,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )
    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
    # warm-up finished

    for forward_id in range(0, 1):
        print("forward_id", forward_id)
        for layer_id in range(1):
            print(layer_id)
            flashinfer_folder = "./kv_cache_flashinfer"
            forward_id = 17
            layer_id = 0
            file_name = f"layer_{layer_id}.pt"
            kv_cache_path = os.path.join(flashinfer_folder, file_name)
            flashinfer_folder = "./flashinfer_output"

            q_len = 1
            kv_len = 126
            file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 512).to(device="cuda")
            file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 64).to(device="cuda")
            q = torch.cat([q_nope, q_pe], dim=-1)

            kv_cache = torch.load(kv_cache_path).to(device="cuda")
            pages, page_size, _, head_dim = kv_cache.shape
            kv_cache = kv_cache.view(pages, page_size, head_dim)
            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)

            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
            wrapper.plan(
                None,
                None,
                None,
                kv_len_arr,
                128,
                512,
                64,
                page_size,
                192 ** (-0.5),
                torch.bfloat16,
                torch.bfloat16,
            )

            graph.replay()

            q_nope_buf.copy_(q_nope)
            q_pe_buf.copy_(q_pe)
            kv_buf[:pages].copy_(kv_cache)

            k = (
                torch.cat([ckv, k_pe], dim=-1)
                .view(-1, 1, 512 + 64)
                .repeat_interleave(num_heads, dim=1)
            )
            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

            print(k[:kv_len].shape)
            print(v[:kv_len].shape)

            # ref_torch
            k = (
                torch.cat([ckv, k_pe], dim=-1)
                .view(-1, 1, 512 + 64)
                .repeat_interleave(num_heads, dim=1)
            )
            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
            attn_ref, lse_ref = attention_ref_torch(
                max_batch_size,
                q,
                k[:kv_len],
                v[:kv_len],
                False,
                192 ** (-0.5),
            )
            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)

            # ref_triton
            attn_logits = torch.empty(
                (
                    max_batch_size,
                    num_heads,
                    4, #num_kv_splits # follow vLLM, fix it TODO
                    512 + 1,
                ),
                dtype=torch.float32,
                device="cuda",
            )
            triton_ref = torch.zeros_like(q_nope)
            page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
            ckv = ckv.view(pages, page_size, 1, 512)
            decode_attention_fwd_grouped(
                q,
                ckv_with_pe,
                ckv,
                triton_ref,
                page_table,
                kv_len_arr,
                attn_logits,
                4, #num_kv_splits # follow vLLM, fix it TODO
                192 ** (-0.5),
                page_size,
            )

            attn_ref, lse_ref = attention_ref(
                max_batch_size,
                torch.cat([q_nope, q_pe], dim=-1),
                k[:kv_len],
                v[:kv_len],
                True,
                192 ** (-0.5),
            )
            print(attn_ref.shape)
            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)

            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
            #ktrans_output = torch.load(file_name)
            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
            print("test past")
            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
            print("test past")
\ No newline at end of file
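The test above relies on the standard CUDA Graphs pattern: the kernels are captured once against fixed buffers (q_nope_buf, q_pe_buf, kv_buf), and each new input is copied into those same buffers before graph.replay(). A minimal, generic PyTorch sketch of that pattern (not repository code) is:

# Capture/copy/replay sketch: the graph records kernels against these exact buffers,
# so fresh inputs must be copied into them in place before every replay.
import torch

if torch.cuda.is_available():
    static_in = torch.randn(8, device="cuda")
    static_out = torch.empty(8, device="cuda")

    # Warm up on a side stream before capture, as the CUDA Graphs docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)  # recorded against static_in / static_out

    static_in.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))  # refresh inputs in place
    g.replay()                 # re-runs the recorded kernels on the new data
    torch.cuda.synchronize()
    print(static_out)          # doubled values of the new input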
ktransformers/server/backend/interfaces/transformers.py
@@ -344,7 +344,7 @@ class TransformersInterface(BackendInterfaceBase):
-                MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32)+1, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32)+1, num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
            next_token = self.decode_one_tokens()
            self.profiler.inc("decode")
            if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
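Both this change and the matching one in ktransformers/util/utils.py below replace a locally recomputed scale, (qk_rope_head_dim + qk_nope_head_dim) ** (-0.5) (i.e. 192 ** (-0.5) for DeepSeek-style MLA), with the attention module's own softmax_scale. The sketch below illustrates why the two can differ; the mscale correction is modeled on DeepSeek-V2/V3 modeling code, and the concrete numbers are assumptions, not values taken from this repository.

# Hedged sketch (not repository code): the layer's softmax_scale can fold in a
# YaRN/rope-scaling correction, so recomputing 1/sqrt(q_head_dim) on the caller
# side silently drops that factor. Values below are illustrative assumptions.
qk_nope_head_dim = 128
qk_rope_head_dim = 64

naive_scale = (qk_nope_head_dim + qk_rope_head_dim) ** (-0.5)  # what the old call passed

mscale = 1.0  # > 1.0 when long-context rope scaling is active; read from the model config in practice
softmax_scale = naive_scale * mscale * mscale  # what layers[0].self_attn.softmax_scale holds in DeepSeek-style models

print(naive_scale, softmax_scale)  # equal only when mscale == 1.0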
ktransformers/tests/humaneval/eval_api.py
@@ -85,7 +85,8 @@ def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file,
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="API Generate Tester")
-    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
    parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
    parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
ktransformers/util/utils.py
@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
            if use_flashinfer_mla:
-                MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1, num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
+                MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1, num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size, model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
            global warm_uped
            if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
                warm_uped = True