OpenDAS / ktransformers

Commit d453c320, authored Mar 07, 2025 by Atream
fix flashinfer precision
Parent: 96d75d53

Showing 5 changed files with 148 additions and 58 deletions.
ktransformers/operators/attention.py (+1, -1)
ktransformers/operators/flashinfer_wrapper.py (+143, -54)
ktransformers/server/backend/interfaces/transformers.py (+1, -1)
ktransformers/tests/humaneval/eval_api.py (+2, -1)
ktransformers/util/utils.py (+1, -1)
ktransformers/operators/attention.py

@@ -25,7 +25,7 @@ from ktransformers.operators.triton_attention import decode_attention_fwd_groupe
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
-    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 logger = logging.getLogger("attention")
ktransformers/operators/flashinfer_wrapper.py

 '''
 Description : flashinfer MLA wrapper
 Author      : Boxin Zhang
-Version     : 0.2.2
+Version     : 0.2.3
 '''
 import torch
+import os
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 flashinfer_enabled = False

@@ -17,7 +19,7 @@ except ImportError:
 import math

-def attention_ref(
+def attention_ref_torch(
     batch_size,
     q: torch.Tensor,
     k: torch.Tensor,

@@ -122,7 +124,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,

@@ -139,11 +141,6 @@ class MLAWrapper():
         )

     def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
-        #print("run")
-        #print(self.wrapper._qo_indptr_buf)
-        #print(self.wrapper._kv_indptr_buf)
-        #print(self.wrapper._kv_indices_buf)
-        #print(self.wrapper._kv_len_arr_buf)
         return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)

 class MLAWrapperSingleton():

@@ -202,21 +199,59 @@ class MLAWrapperSingleton():
         wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
         wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
         wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

-flashinfer_folder = "./flashinfer_output"
-triton_folder = "./triton_output"
+def checksame():
+    flashinfer_folder = "./kv_cache_flashinfer"
+    triton_folder = "./kv_cache_triton"
+
+    max_layer_id = 1
+    max_forward_id = 2
+
+    for forward_id in range(0, 19):
+        print("forward_id", forward_id)
+        for layer_id in range(max_layer_id):
+            print(layer_id)
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+            #file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+            file_name = f"layer_{layer_id}.pt"
+            flashinfer_path = os.path.join(flashinfer_folder, file_name)
+            triton_path = os.path.join(triton_folder, file_name)
+            if not os.path.exists(triton_path):
+                print(f"{file_name} not exist in {triton_folder}")
+                continue
+            if not os.path.exists(flashinfer_path):
+                print(f"{file_name} not exist in {flashinfer_folder}")
+                continue
+            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#
+            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#
+            try:
+                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)
+            except AssertionError as e:
+                print(e)

 if __name__ == "__main__":
+    torch.set_default_dtype(torch.bfloat16)
+    #checksame()
+    #exit(0)
+
     max_batch_size = 1
-    max_pages = 128
+    max_pages = 64
     page_size = 64
     num_heads = 128

+    # warm-up
     kv_len = 4023
     q_len = 1
-    q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
-    q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
-    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
-    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
+    q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
+    q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
+    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
+    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)

     wrapper = MLAWrapperSingleton.get_instance(

@@ -241,51 +276,105 @@ if __name__ == "__main__":
         torch.bfloat16,
     )

-    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+    attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
     print(attn_output.shape)

     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph):
-        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
-
-    kv_len = 6789
-    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
-    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
-    wrapper.plan(
-        qo_indptr,
-        None,
-        None,
-        kv_len_arr,
-        128,
-        512,
-        64,
-        page_size,
-        192 ** (-0.5),
-        torch.bfloat16,
-        torch.bfloat16,
-    )
-
-    graph.replay()
-    torch.cuda.synchronize()
-
-    k = (
-        torch.cat([ckv, k_pe], dim=-1)
-        .view(-1, 1, 512 + 64)
-        .repeat_interleave(num_heads, dim=1)
-    )
-    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
-    print(k[:kv_len].shape)
-    print(v[:kv_len].shape)
-
-    attn_ref, lse_ref = attention_ref(
-        max_batch_size,
-        torch.cat([q_nope, q_pe], dim=-1),
-        k[:kv_len],
-        v[:kv_len],
-        True,
-        192 ** (-0.5))
-    print(attn_ref.shape)
-    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
-    print("test past")
+        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
+    # warm-up finished
+
+    for forward_id in range(0, 1):
+        print("forward_id", forward_id)
+        for layer_id in range(1):
+            print(layer_id)
+            flashinfer_folder = "./kv_cache_flashinfer"
+            forward_id = 17
+            layer_id = 0
+            file_name = f"layer_{layer_id}.pt"
+            kv_cache_path = os.path.join(flashinfer_folder, file_name)
+            flashinfer_folder = "./flashinfer_output"
+
+            q_len = 1
+            kv_len = 126
+            file_name = f"layer_{layer_id}_forward_{forward_id}_q_nope.pt"
+            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 512).to(device="cuda")
+            file_name = f"layer_{layer_id}_forward_{forward_id}_q_pe.pt"
+            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len, 128, 64).to(device="cuda")
+            q = torch.cat([q_nope, q_pe], dim=-1)
+
+            kv_cache = torch.load(kv_cache_path).to(device="cuda")
+            pages, page_size, _, head_dim = kv_cache.shape
+            kv_cache = kv_cache.view(pages, page_size, head_dim)
+            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
+
+            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+            wrapper.plan(
+                None,
+                None,
+                None,
+                kv_len_arr,
+                128,
+                512,
+                64,
+                page_size,
+                192 ** (-0.5),
+                torch.bfloat16,
+                torch.bfloat16,
+            )
+
+            q_nope_buf.copy_(q_nope)
+            q_pe_buf.copy_(q_pe)
+            kv_buf[:pages].copy_(kv_cache)
+            torch.cuda.synchronize()
+            graph.replay()
+            torch.cuda.synchronize()
+
+            # ref_torch
+            k = (
+                torch.cat([ckv, k_pe], dim=-1)
+                .view(-1, 1, 512 + 64)
+                .repeat_interleave(num_heads, dim=1)
+            )
+            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
+            attn_ref, lse_ref = attention_ref_torch(max_batch_size, q, k[:kv_len], v[:kv_len], False, 192 ** (-0.5))
+            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
+
+            # ref_triton
+            attn_logits = torch.empty(
+                (
+                    max_batch_size,
+                    num_heads,
+                    4, #num_kv_splits # follow vLLM, fix it TODO
+                    512 + 1,
+                ),
+                dtype=torch.float32, device="cuda")
+            triton_ref = torch.zeros_like(q_nope)
+            page_table = torch.arange(max_pages, dtype=torch.int32, device="cuda")
+            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)
+            ckv = ckv.view(pages, page_size, 1, 512)
+            decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref, page_table, kv_len_arr, attn_logits,
+                                         4, #num_kv_splits # follow vLLM, fix it TODO
+                                         192 ** (-0.5), page_size)
+            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)
+
+            #file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
+            #ktrans_output = torch.load(file_name)
+            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
+
+            print("test past")
\ No newline at end of file
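The rewritten test above validates the FlashInfer MLA decode path against both a plain-PyTorch reference (attention_ref_torch) and the Triton decode kernel, using torch.testing.assert_close at rtol/atol 1e-3. For readers who want the general shape of that kind of precision check, here is a minimal, self-contained sketch that compares a naive softmax-attention reference against torch.nn.functional.scaled_dot_product_attention. The shapes, the float32 dtype, and the tolerances are illustrative assumptions rather than values taken from the repository, and the scale= keyword assumes PyTorch 2.1 or newer.

import torch
import torch.nn.functional as F

def attention_reference(q, k, v, sm_scale):
    # Naive softmax attention over a [batch, heads, seq, dim] layout.
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)

if __name__ == "__main__":
    torch.manual_seed(0)
    b, h, q_len, kv_len, d = 1, 8, 1, 126, 64   # illustrative sizes only
    q = torch.randn(b, h, q_len, d)
    k = torch.randn(b, h, kv_len, d)
    v = torch.randn(b, h, kv_len, d)
    sm_scale = d ** (-0.5)

    ref = attention_reference(q, k, v, sm_scale)
    out = F.scaled_dot_product_attention(q, k, v, scale=sm_scale)

    # Same comparison style as the commit's test, which applies it to
    # bfloat16 CUDA kernel outputs and therefore needs loose tolerances.
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
    print("test past")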
ktransformers/server/backend/interfaces/transformers.py

@@ -344,7 +344,7 @@ class TransformersInterface(BackendInterfaceBase):
             MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
                                          num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                          head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
-                                         sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                                         sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
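This change, and the matching one in ktransformers/util/utils.py below, stops recomputing sm_scale as a bare (qk_rope_head_dim + qk_nope_head_dim) ** (-0.5) and instead passes the attention module's own softmax_scale into plan_all. A minimal sketch of why the two values can differ follows; the config numbers and the mscale correction are illustrative assumptions (DeepSeek-style MLA models may fold a rope-scaling correction into softmax_scale), not values read from this repository.

# Hypothetical config values for illustration only.
qk_nope_head_dim = 128
qk_rope_head_dim = 64
mscale = 1.2  # arbitrary illustrative rope-scaling correction; 1.0 when no scaling is configured

q_head_dim = qk_nope_head_dim + qk_rope_head_dim

# What the code computed before this commit:
naive_sm_scale = q_head_dim ** (-0.5)

# What the attention module itself may carry (sketch): the same base scale
# multiplied by a correction factor when rope scaling is active.
softmax_scale = q_head_dim ** (-0.5) * mscale * mscale

print(naive_sm_scale, softmax_scale)
# Passing the module's softmax_scale to MLAWrapperSingleton.plan_all keeps the
# FlashInfer decode path numerically consistent with the eager attention path.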
ktransformers/tests/humaneval/eval_api.py

@@ -85,7 +85,8 @@ def main(output_path, api_url, model_name, auth_token, format_tabs,problem_file,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="API Generate Tester")
-    parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    #parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
+    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--model_name", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model Name")
     parser.add_argument("--out_path", type=str, default="results/api/eval_b.jsonl", help="Output Path")
     parser.add_argument("--auth_token", type=str, default=None, help="Auth Token")
ktransformers/util/utils.py

@@ -239,7 +239,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             if use_flashinfer_mla:
                 MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1) + 1,
                                              num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
-                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
+                                             model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)
             global warm_uped
             if use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
                 warm_uped = True
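The warm_uped / use_cuda_graph condition above gates CUDA-graph capture, and the flashinfer_wrapper.py test earlier in this commit exercises the same capture-then-replay pattern: capture the kernel once over static buffers, then copy fresh inputs into those buffers and replay. Below is a minimal, self-contained sketch of that pattern on a toy computation; it is not the repository's code, and it only runs the graph portion when a CUDA device is available.

import torch

def main():
    if not torch.cuda.is_available():
        print("CUDA not available; skipping graph demo")
        return

    # Static buffers that the captured graph reads from and writes to.
    x_buf = torch.zeros(8, device="cuda")
    y_buf = torch.zeros(8, device="cuda")

    # Warm up on a side stream, as recommended before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y_buf.copy_(x_buf * 2 + 1)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the computation into a CUDA graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y_buf.copy_(x_buf * 2 + 1)

    # Replay with new inputs: copy into the static buffer, then replay.
    x_buf.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(y_buf)  # tensor([ 1., 3., 5., ..., 15.])

if __name__ == "__main__":
    main()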