ox696c / ktransformers · Commit f35e8d41

Authored Mar 01, 2025 by Atream
Parent: 494469d4

    support chunk prefill, support 139K context for 24G VRAM

Showing 10 changed files with 228 additions and 84 deletions (+228 −84)
ktransformers/local_chat.py                                +3    −2
ktransformers/operators/attention.py                       +1    −30
ktransformers/operators/flashinfer_wrapper.py              +25   −1
ktransformers/server/args.py                               +1    −1
ktransformers/server/backend/args.py                       +1    −1
ktransformers/server/backend/interfaces/ktransformers.py   +32   −21
ktransformers/server/backend/interfaces/transformers.py    +0    −4
ktransformers/server/config/config.py                      +2    −1
ktransformers/util/utils.py                                +37   −22
test_prompt.txt                                            +126  −1
ktransformers/local_chat.py

@@ -62,6 +62,7 @@ def local_chat(
     prompt_file : str | None = None,
     mode: str = "normal",
     force_think: bool = False,
+    chunk_prefill_size: int = 8192
 ):
     torch.set_grad_enabled(False)

@@ -170,12 +171,12 @@ def local_chat(
     if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
             use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
         )
     else:
         generated = prefill_and_generate(
-            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
         )
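A minimal usage sketch of the new argument (illustrative; this diff only confirms the prompt_file, mode, force_think and chunk_prefill_size parameters, so the other keyword names and paths below are assumptions):

# Hypothetical call into the patched entry point; paths are placeholders.
from ktransformers.local_chat import local_chat

local_chat(
    model_path="deepseek-ai/DeepSeek-V2-Lite-Chat",   # assumed parameter name and value
    gguf_path="./DeepSeek-V2-Lite-Chat-GGUF",         # assumed parameter name and value
    max_new_tokens=1000,                              # assumed parameter name
    chunk_prefill_size=8192,                          # new in this commit: prefill the prompt 8192 tokens at a time
)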
ktransformers/operators/attention.py

@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux_flashinfer_chunk(
+    def forward_linux_flashinfer(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,

@@ -512,35 +512,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value

-    def forward_linux_flashinfer(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        if q_len <= self.chunck_size or not self.absorb_for_prefill:
-            return self.forward_linux_flashinfer_chunk(
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position,
-                **kwargs,
-            )
-        assert False

     def forward_windows(
         self,
         hidden_states: torch.Tensor,
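With this change the attention module no longer guards long prefills itself (the removed dispatcher asserted once q_len exceeded self.chunck_size); the callers patched later in this commit slice the prompt instead. A hypothetical sketch of that caller-side contract (not repo code; attn_module, past_key_value and the tensor arguments are stand-ins):

import torch

def chunked_attention_prefill(attn_module, hidden_states: torch.Tensor,
                              cache_position: torch.Tensor, past_key_value,
                              chunk_prefill_size: int = 8192):
    """Feed the prompt to attention in slices so each forward sees q_len <= chunk_prefill_size."""
    q_total = hidden_states.shape[1]
    attn_output = None
    for start in range(0, q_total, chunk_prefill_size):
        end = min(start + chunk_prefill_size, q_total)
        attn_output, _, past_key_value = attn_module.forward_linux_flashinfer(
            hidden_states[:, start:end],              # only this chunk's queries
            cache_position=cache_position[start:end],
            past_key_value=past_key_value,            # earlier chunks are already in the KV cache
            use_cache=True,
        )
    return attn_output, past_key_value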
ktransformers/operators/flashinfer_wrapper.py

@@ -205,12 +205,13 @@ class MLAWrapperSingleton():
 if __name__ == "__main__":
+    torch.set_default_dtype(torch.bfloat16)
     max_batch_size = 1
     max_pages = 128
     page_size = 64
     num_heads = 128

-    kv_len = 2069
+    kv_len = 4023
     q_len = 1
     q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
     q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")

@@ -242,6 +243,29 @@ if __name__ == "__main__":
     attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
     print(attn_output.shape)
+
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
+
+    kv_len = 6789
+    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
+    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
+    wrapper.plan(
+        qo_indptr,
+        None,
+        None,
+        kv_len_arr,
+        128,
+        512,
+        64,
+        page_size,
+        192 ** (-0.5),
+        torch.bfloat16,
+        torch.bfloat16,
+    )
+
+    graph.replay()

     k = (
         torch.cat([ckv, k_pe], dim=-1)
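The block added above captures wrapper.run() into a CUDA graph, re-plans the wrapper for a longer kv_len, and then replays the graph. A generic PyTorch sketch of why that works (illustrative only, not the MLAWrapper API): replay() re-executes the captured kernels as-is, so only data written into pre-allocated buffers before the replay can change the result.

import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in * 2)   # kernels captured once, tensor addresses baked in

static_in.fill_(3.0)   # update inputs in place, analogous to wrapper.plan() refreshing kv metadata
graph.replay()         # re-runs the captured kernels against the new buffer contents
print(static_out)      # tensor([6., 6., ...], device='cuda:0')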
ktransformers/server/args.py

@@ -24,13 +24,13 @@ class ArgumentParser:
         parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
         parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
         parser.add_argument("--type", type=str, default=self.cfg.backend_type)
+        parser.add_argument("--chunk_prefill_size", type=int, default=8192)

         # model configs
         # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
         parser.add_argument("--paged", type=bool, default=self.cfg.paged)
         parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
         parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
-        parser.add_argument("--max_chunk_size", type=int, default=self.cfg.max_chunk_size)
         parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
         parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
         parser.add_argument("--healing", type=bool, default=self.cfg.healing)
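A quick, self-contained check of how the new server flag parses (plain argparse sketch, not the repo's ArgumentParser class):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--chunk_prefill_size", type=int, default=8192)

args = parser.parse_args(["--chunk_prefill_size", "4096"])
assert args.chunk_prefill_size == 4096   # smaller chunks trade slower prefill for lower peak VRAM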
ktransformers/server/backend/args.py

@@ -23,7 +23,7 @@ class ConfigArgs(BaseModel):
     max_batch_size: int = Field(
         None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
     )
-    max_chunk_size: int = Field(
+    chunk_prefill_size: int = Field(
         None,
         description=(
             "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
ktransformers/server/backend/interfaces/ktransformers.py

@@ -111,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
             warm_uped = True
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]

@@ -167,41 +165,54 @@ class KTransformersInterface(TransformersInterface):
         self.ever_generated_ids.clear()
         self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")

         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
                 self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
             )
             self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-        mask = torch.ones((1, self.seq_length)).to(device)

         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        torch.cuda.set_device(device)
-        if flashinfer_enabled:
-            MLAWrapperSingleton.need_plan_all()
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size

         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
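The loop above walks cache_position in slices of self.args.chunk_prefill_size, so a long prompt becomes several bounded forward passes over the same static cache. A quick sketch of the resulting chunk count (illustrative numbers; the 139K figure comes from the commit message):

import math

input_ids_length = 139_000          # long prompt from the commit message
chunk_prefill_size = 8192           # new default
num_chunks = math.ceil(input_ids_length / chunk_prefill_size)
print(num_chunks)                   # 17 prefill passes, each bounded by the chunk size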
ktransformers/server/backend/interfaces/transformers.py

@@ -242,12 +242,10 @@ class TransformersInterface(BackendInterfaceBase):
     def decode_one_tokens(self):
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]

@@ -309,7 +307,6 @@ class TransformersInterface(BackendInterfaceBase):
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-        mask = torch.ones((1, self.seq_length)).to(self.args.device)
         device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")

@@ -321,7 +318,6 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
-                attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
ktransformers/server/config/config.py

@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
-        self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
ktransformers/util/utils.py

@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []

-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
         if cuda_graph_runner is None:
             use_cuda_graph = False
         if use_cuda_graph:

@@ -152,24 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token

-    torch.cuda.set_device(torch_device)
-    with torch.no_grad():
-        stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
-            past_key_values = StaticCache(
-                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-            )
-        else:
-            past_key_values = None
-        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
-        generated_ids = torch.zeros(
-            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
-        )
-        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        if past_key_values != None:
-            past_key_values.cur_idx = cache_position
-        start_time = time.time()
+    # TODO: use CUDA Graph for chunk prefill, may get small improvement
+    def chunk_prefill(inputs, cache_position, past_key_values):
         if mode == "long_context":
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:

@@ -181,6 +165,20 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:, -1, :].unsqueeze(0).clone().to(torch_device)
+        return logits
+
+    torch.cuda.set_device(torch_device)
+    with torch.no_grad():
+        stream = TextStreamer(tokenizer)
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
         generation_config, model_kwargs = model._prepare_generation_config(
             None, do_sample=True
             # change this to modify generate config

@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         logits_warper = (
             model._get_logits_warper(generation_config)
         )
+        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
+        generated_ids = torch.zeros(
+            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
+        )
+        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
+        start_time = time.time()
+
+        chunk_start = 0
+        while chunk_start < seq_length:
+            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            if past_key_values != None:
+                past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
+            chunk_start += chunk_prefill_size
+
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
         first_token_time = time.time() - start_time
         if use_flashinfer_mla:

@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         prefill_count = seq_length
         prefill_time = first_token_time
         if force_think:
-            print("<think>\n")
+            print("<think>")
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))

@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(int(next_token))
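decode_one_tokens now receives logits_warper and generation_config, so the decode path can apply the same warp-then-sample step as the prefill tail above. A minimal sketch of that shared step (illustrative restatement, not repo code):

import torch
from torch import nn

def select_next_token(inputs, logits, logits_warper, generation_config):
    scores = logits_warper(inputs, logits[:, -1, :])       # temperature / top-k / top-p warping
    if generation_config.do_sample:
        probs = nn.functional.softmax(scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(scores, dim=-1)                    # greedy fallback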
test_prompt.txt

This diff is collapsed (large prompt file, +126 −1; contents not shown).