OpenDAS / ktransformers / Commits

Commit f35e8d41, authored Mar 01, 2025 by Atream
Parent: 494469d4

    support chunk prefill, support 139K context for 24G VRAM

Showing 10 changed files with 228 additions and 84 deletions.
ktransformers/local_chat.py                                  +3    -2
ktransformers/operators/attention.py                         +1    -30
ktransformers/operators/flashinfer_wrapper.py                +25   -1
ktransformers/server/args.py                                 +1    -1
ktransformers/server/backend/args.py                         +1    -1
ktransformers/server/backend/interfaces/ktransformers.py     +32   -21
ktransformers/server/backend/interfaces/transformers.py      +0    -4
ktransformers/server/config/config.py                        +2    -1
ktransformers/util/utils.py                                  +37   -22
test_prompt.txt                                              +126  -1
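Taken together, the change replaces one full-length prefill forward pass with a loop over fixed-size chunks that share a single static KV cache, so peak activation memory scales with chunk_prefill_size rather than with prompt length. A minimal, standalone sketch of that loop in plain PyTorch/transformers (not the ktransformers code itself; the model id and prompt are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Placeholder model id: any HF causal LM that supports StaticCache works for this sketch.
model_id = "your-model-here"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

chunk_prefill_size = 8192                       # default introduced by this commit
prompt = "long prompt text " * 4096             # stand-in for a very long prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
seq_length = input_ids.shape[1]
max_new_tokens = 256

past_key_values = StaticCache(
    config=model.config, max_batch_size=1,
    max_cache_len=seq_length + max_new_tokens, device="cuda", dtype=model.dtype,
)
cache_position = torch.arange(seq_length, device="cuda")

with torch.no_grad():
    chunk_start = 0
    while chunk_start < seq_length:
        chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
        # Only one chunk's activations are resident at a time; the KV cache
        # accumulates across iterations via cache_position.
        logits = model(
            input_ids[:, chunk_start:chunk_end],
            cache_position=cache_position[chunk_start:chunk_end],
            past_key_values=past_key_values,
            use_cache=True,
        ).logits
        chunk_start += chunk_prefill_size

next_token = logits[:, -1, :].argmax(dim=-1)    # greedy first token after prefill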
ktransformers/local_chat.py

@@ -62,6 +62,7 @@ def local_chat(
     prompt_file: str | None = None,
     mode: str = "normal",
     force_think: bool = False,
+    chunk_prefill_size: int = 8192,
 ):
     torch.set_grad_enabled(False)

@@ -170,12 +171,12 @@ def local_chat(
     if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think,
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
             use_flashinfer_mla=True, num_heads=config.num_attention_heads, head_dim_ckv=config.kv_lora_rank, head_dim_kpe=config.qk_rope_head_dim, q_head_dim=config.qk_rope_head_dim + config.qk_nope_head_dim
         )
     else:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think,
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode=mode, force_think=force_think, chunk_prefill_size=chunk_prefill_size,
         )
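For the flashinfer path, local_chat.py forwards DeepSeek's MLA geometry straight from the model config: the compressed-KV head dim is kv_lora_rank, the rotary part is qk_rope_head_dim, and the query head dim is the rope part plus the nope part. A tiny sketch of that mapping with the usual published DeepSeek-V2/V3 config values (illustrative numbers; the real ones come from the model's config.json):

# Typical DeepSeek-V2/V3 MLA dimensions as read from the HF config.
config = {
    "num_attention_heads": 128,
    "kv_lora_rank": 512,        # head_dim_ckv: compressed-KV per-head dim
    "qk_rope_head_dim": 64,     # head_dim_kpe: rotary (positional) part
    "qk_nope_head_dim": 128,    # non-rotary part of the query/key head
}

head_dim_ckv = config["kv_lora_rank"]                                   # 512
head_dim_kpe = config["qk_rope_head_dim"]                               # 64
q_head_dim = config["qk_rope_head_dim"] + config["qk_nope_head_dim"]    # 192

print(head_dim_ckv, head_dim_kpe, q_head_dim)

These are the same 512 / 64 / 192 values exercised by the flashinfer_wrapper self-test below.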
ktransformers/operators/attention.py

@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux_flashinfer_chunk(
+    def forward_linux_flashinfer(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,

@@ -512,35 +512,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux_flashinfer(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        if q_len <= self.chunck_size or not self.absorb_for_prefill:
-            return self.forward_linux_flashinfer_chunk(
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position,
-                **kwargs,
-            )
-
-        assert False
-
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
ktransformers/operators/flashinfer_wrapper.py

@@ -205,12 +205,13 @@ class MLAWrapperSingleton():

 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
     max_batch_size = 1
     max_pages = 128
     page_size = 64
     num_heads = 128

-    kv_len = 2069
+    kv_len = 4023
     q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")

@@ -242,6 +243,29 @@ if __name__ == "__main__":
     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
     print(attn_output.shape)

+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+
+    kv_len = 6789
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+    wrapper.plan(
+        qo_indptr,
+        None,
+        None,
+        kv_len_arr,
+        128,
+        512,
+        64,
+        page_size,
+        192 ** (-0.5),
+        torch.bfloat16,
+        torch.bfloat16,
+    )
+
+    graph.replay()
+
     k = (
         torch.cat([ckv, k_pe], dim=-1)
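The extended self-test captures wrapper.run(...) in a CUDA graph, re-plans the wrapper for a longer kv_len, and then replays the captured graph, so the decode kernels are reused without re-capture. The same capture/update/replay pattern can be sketched with plain PyTorch (a toy matmul stands in for the attention kernel; nothing here is flashinfer-specific):

import torch

assert torch.cuda.is_available()
device = "cuda"

# Static buffers: CUDA graphs require fixed tensor addresses and shapes.
x = torch.randn(1, 1024, device=device)
w = torch.randn(1024, 1024, device=device)
y = torch.empty(1, 1024, device=device)

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y.copy_(x @ w)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y.copy_(x @ w)          # captured work, analogous to wrapper.run(...)

# "Re-plan": overwrite the static input buffer in place ...
x.copy_(torch.randn(1, 1024, device=device))
# ... then replay the captured kernels against the new contents.
g.replay()
torch.cuda.synchronize()
print(y.shape)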
ktransformers/server/args.py

@@ -24,13 +24,13 @@ class ArgumentParser:
         parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_prefill_size", type=int, default=8192)

         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
-        parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
         parser.add_argument("--healing", type=bool, default=self.cfg.healing)
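The new --chunk_prefill_size flag is an ordinary integer argument; a minimal, standalone argparse sketch (not the server's actual ArgumentParser class) shows how it parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--chunk_prefill_size", type=int, default=8192,
                    help="Prompt tokens processed per prefill chunk; lower it to trade "
                         "prefill speed for less peak VRAM.")
args = parser.parse_args(["--chunk_prefill_size", "4096"])
print(args.chunk_prefill_size)  # 4096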
ktransformers/server/backend/args.py

@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
     max_batch_size: int = Field(None, description="Max number of batches to run at once, assuming the sequences will fit within total_context")
-    max_chunk_size: int = Field(
+    chunk_prefill_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
ktransformers/server/backend/interfaces/ktransformers.py

@@ -111,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
             warm_uped = True
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]

@@ -167,41 +165,54 @@ class KTransformersInterface(TransformersInterface):
         self.ever_generated_ids.clear()
         self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")

         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
                 self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
             )
             self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

-        mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        torch.cuda.set_device(device)
-        if flashinfer_enabled:
-            MLAWrapperSingleton.need_plan_all()
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size

         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
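The interface now caps the id buffer at args.cache_lens and grows it only by the missing amount before running the chunked prefill loop. A small, self-contained sketch of that bookkeeping (the helper name is illustrative, not a ktransformers API):

import torch

def grow_generated_ids(generated_ids: torch.Tensor, seq_length: int,
                       max_new_tokens: int, cache_lens: int) -> torch.Tensor:
    # Target capacity: prompt + new tokens + 1, but never beyond the static cache.
    expected_length = min(seq_length + max_new_tokens + 1, cache_lens)
    delta = expected_length - generated_ids.shape[-1]
    if delta > 0:
        pad = torch.zeros(generated_ids.shape[0], delta,
                          dtype=generated_ids.dtype, device=generated_ids.device)
        return torch.cat([generated_ids, pad], dim=-1)
    if seq_length > cache_lens:
        # Mirrors the diff's warning path: the prompt alone no longer fits in the cache.
        raise RuntimeError("seq_length bigger than cache_lens")
    return generated_ids

buf = torch.zeros(1, 16, dtype=torch.int)
buf = grow_generated_ids(buf, seq_length=20, max_new_tokens=8, cache_lens=64)
print(buf.shape)   # torch.Size([1, 29])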
ktransformers/server/backend/interfaces/transformers.py

@@ -242,12 +242,10 @@ class TransformersInterface(BackendInterfaceBase):
     def decode_one_tokens(self):
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]

@@ -309,7 +307,6 @@ class TransformersInterface(BackendInterfaceBase):
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

-        mask = torch.ones((1, self.seq_length)).to(self.args.device)
         device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")

@@ -321,7 +318,6 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
-                attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
ktransformers/server/config/config.py

@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
ktransformers/util/utils.py

@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []

-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
         if cuda_graph_runner is None:
             use_cuda_graph = False
         if use_cuda_graph:

@@ -152,24 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token

-    torch.cuda.set_device(torch_device)
-    with torch.no_grad():
-        stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
-            past_key_values = StaticCache(
-                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-            )
-        else:
-            past_key_values = None
-        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
-        generated_ids = torch.zeros(
-            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
-        )
-        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        if past_key_values != None:
-            past_key_values.cur_idx = cache_position
-        start_time = time.time()
-
+    # TODO: use CUDA Graph for chunk prefill, may get small improvement
     def chunk_prefill(inputs, cache_position, past_key_values):
         if mode == "long_context":
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:

@@ -181,6 +165,20 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits = model(inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0][:, -1, :].unsqueeze(0).clone().to(torch_device)
+        return logits
+
+    torch.cuda.set_device(torch_device)
+    with torch.no_grad():
+        stream = TextStreamer(tokenizer)
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
         generation_config, model_kwargs = model._prepare_generation_config(
             None, do_sample=True
             # change this to modify generate config

@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         logits_warper = (
             model._get_logits_warper(generation_config)
         )
+
+        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
+        generated_ids = torch.zeros(
+            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
+        )
+        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
+        start_time = time.time()
+
+        chunk_start = 0
+        while chunk_start < seq_length:
+            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            if past_key_values != None:
+                past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
+            chunk_start += chunk_prefill_size
+
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         first_token_time = time.time() - start_time
         if use_flashinfer_mla:

@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         prefill_count = seq_length
         prefill_time = first_token_time
         if force_think:
-            print("<think>\n")
+            print("<think>")
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))

@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
            tokens.append(int(next_token))
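After the last prefill chunk, the final-position logits are warped and either sampled (do_sample) or taken greedily. A standalone sketch of that step, writing the temperature and top-p warping out with public PyTorch ops instead of transformers' private _get_logits_warper (the default values here are arbitrary and only for illustration):

import torch

def sample_next_token(logits, temperature=0.7, top_p=0.95, do_sample=True):
    # logits: [batch, vocab] scores for the last prompt position.
    scores = logits / max(temperature, 1e-5)
    if not do_sample:
        return scores.argmax(dim=-1)
    sorted_logits, sorted_indices = torch.sort(scores, descending=True, dim=-1)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mask tokens past the top-p nucleus, always keeping the most likely token.
    sorted_remove = cumulative_probs > top_p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    scores = scores.masked_fill(remove, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

logits = torch.randn(1, 32000)
print(sample_next_token(logits).shape)   # torch.Size([1])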
test_prompt.txt

This source diff could not be displayed because it is too large. You can view the blob instead.