OpenDAS / ktransformers · Commits

Commit 5aee6c04, authored Mar 03, 2025 by Azure
Merge branch 'main' into develop-0.2.3
Parents: 216a63b8, 48b98007
Changes: 28

Showing 8 changed files with 123 additions and 101 deletions (+123 -101)
ktransformers/server/backend/interfaces/ktransformers.py    +53  -38
ktransformers/server/backend/interfaces/transformers.py     +21  -15
ktransformers/server/config/config.py                        +2   -1
ktransformers/server/schemas/endpoints/chat.py               +3   -1
ktransformers/server/schemas/legacy/completions.py           +2   -0
ktransformers/util/custom_gguf.py                            +2   -1
ktransformers/util/utils.py                                 +40  -25
test_prompt.txt                                              +0  -20
ktransformers/server/backend/interfaces/ktransformers.py (view file @ 5aee6c04)

@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        try:
+            generation_config = GenerationConfig.from_pretrained(args.model_dir)
+        except:
+            generation_config = GenerationConfig(
+                max_length=args.max_new_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+                do_sample=True
+            )
         torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
             " belong to current model):"
         )
         optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
+        self.model.generation_config = generation_config
         self.device_map = self.model.gguf_loader.tensor_device_map
         # logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
         self.cache = StaticCache(
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
             dtype=self.model.dtype,
         )
         # logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
-        try:
-            self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
-        except:
-            gen_config = GenerationConfig(
-                max_length=128,
-                temperature=0.7,
-                top_p=0.9,
-                do_sample=True
-            )
-            self.model.generation_config = gen_config
         if self.model.generation_config.pad_token_id is None:
             self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
         self.streamer = TextStreamer(self.tokenizer)
@@ -110,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
             warm_uped = True
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -128,10 +127,13 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
         input_ids_length = input_ids.shape[-1]
+        if (input_ids_length >= self.args.cache_lens):
+            logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
@@ -166,44 +168,57 @@ class KTransformersInterface(TransformersInterface):
         self.ever_generated_ids.clear()
         self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")

         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
                 self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
             )
             self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning(f"seq_length bigger than cache_lens, killed")
+            exit(0)

         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

         mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        torch.cuda.set_device(device)
-        if flashinfer_enabled:
-            MLAWrapperSingleton.need_plan_all()
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache != None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size

         if flashinfer_enabled:
             MLAWrapperSingleton.reset_buffer()
-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -212,7 +227,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
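Note on the hunks above: prefill no longer embeds and forwards the whole prompt in a single pass. It wraps the forward in a local chunk_prefill helper and walks the prompt in windows of self.args.chunk_prefill_size, reusing the static KV cache between windows, while inference/prefill now also carry optional per-request temperature and top_p. Below is a minimal standalone sketch of the chunking loop only; run_chunk is an assumed stand-in for the model call, not the project's API.

    # Standalone sketch of a chunked-prefill loop (helper names are assumptions).
    import torch

    def chunked_prefill(input_ids: torch.Tensor, chunk_size: int, run_chunk):
        """Feed the prompt to run_chunk in fixed-size windows.

        run_chunk(ids, cache_position) -> logits stands in for the real model
        forward; the KV cache is assumed to persist across calls.
        """
        seq_len = input_ids.shape[-1]
        cache_position = torch.arange(seq_len, dtype=torch.long)
        logits = None
        start = 0
        while start < seq_len:
            end = min(start + chunk_size, seq_len)
            logits = run_chunk(input_ids[:, start:end], cache_position[start:end])
            start += chunk_size
        # Only the final chunk's logits matter: their last position seeds decoding.
        return logits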
ktransformers/server/backend/interfaces/transformers.py (view file @ 5aee6c04)

@@ -13,6 +13,7 @@ from transformers import (
 from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import ObjectID
 from ktransformers.server.utils.multi_timer import Profiler
+from torch.nn.attention import SDPBackend
 import torch
 import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
@@ -202,20 +203,23 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)

-    def prepare_logits_wrapper(self, inputs, device):
+    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
+        if temperature is None or temperature == 0:
+            temperature = self.model.generation_config.temperature
+        if top_p is None:
+            top_p = self.model.generation_config.top_p
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,
             top_k=self.args.top_k,
-            top_p=self.args.top_p,
-            temperature=self.args.temperature,
+            top_p=top_p,
+            temperature=temperature,
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
         self.generation_config = generation_config
         try: # transformers==4.43
             self.logits_warper = (
                 self.model._get_logits_warper(generation_config, device=device)
             )
         except:
             self.logits_warper = (
@@ -239,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
     def decode_one_tokens(self):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
                 self.current_ids,
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -255,7 +257,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.logits_to_token(logits)

     @torch.no_grad
-    def prefill(self, input_ids: torch.Tensor, is_new: bool):
+    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None):
         input_ids_length = input_ids.shape[-1]
         logger.debug(f"input_ids: {input_ids.shape}")
@@ -306,7 +308,6 @@ class TransformersInterface(BackendInterfaceBase):
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)

         mask = torch.ones((1, self.seq_length)).to(self.args.device)
         device = input_ids.device
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
@@ -318,21 +319,26 @@ class TransformersInterface(BackendInterfaceBase):
                 past_key_values=self.cache,
                 return_dict=False,
                 use_cache=True,
                 attention_mask=mask,
             )[0]
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

-        self.prepare_logits_wrapper(input_ids, device)
+        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)

     @torch.no_grad
     def generate(self):
+        self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
+        if (self.args.max_new_tokens <= 0):
+            logger.warning("max_new_tokens is less than 0")
+            yield self.streamer.end()
+            return
         logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.args.max_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
                         num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
@@ -359,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
@@ -386,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
             print(think, end="", flush=True)
             yield think

-        for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
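One change above worth calling out is the attention-backend context manager: the deprecated torch.backends.cuda.sdp_kernel(...) call is replaced by torch.nn.attention.sdpa_kernel(backends=[...]), which takes an explicit backend list. A minimal usage sketch, independent of this repo (assumes PyTorch 2.3+ and a CUDA device; tensor shapes are illustrative only):

    # sdpa_kernel restricts scaled_dot_product_attention to the listed backends
    # for the enclosed region.
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH,
                               SDPBackend.EFFICIENT_ATTENTION]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)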
ktransformers/server/config/config.py (view file @ 5aee6c04)

@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
+        self.chunk_prefill_size = self.model.get("chunk_prefill_size", 8192)
         self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
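The single added line gives the server a chunk_prefill_size knob, read from the model section of the config with a default of 8192 tokens. A tiny sketch of the same lookup pattern; the dict literal here is illustrative, not the real config file:

    # Missing keys fall back to the documented default.
    model_cfg = {"max_new_tokens": 2000}                       # no chunk_prefill_size key set
    chunk_prefill_size = model_cfg.get("chunk_prefill_size", 8192)
    print(chunk_prefill_size)                                  # -> 8192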
ktransformers/server/schemas/endpoints/chat.py (view file @ 5aee6c04)

@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
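The two new fields let a chat-completions request carry its own sampling settings; leaving them out (or null) keeps the server defaults. A trimmed pydantic sketch of the same idea (the Message model here is simplified to role/content and is an assumption, not the project's full schema):

    # Trimmed sketch of a request model with optional sampling fields.
    from typing import List, Optional
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str
        content: str

    class ChatCompletionCreate(BaseModel):
        messages: List[Message]
        model: str
        stream: bool = False
        temperature: Optional[float] = None   # None -> fall back to server defaults
        top_p: Optional[float] = None

    req = ChatCompletionCreate(model="demo", messages=[{"role": "user", "content": "hi"}])
    assert req.temperature is None and req.top_p is None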
ktransformers/server/schemas/legacy/completions.py (view file @ 5aee6c04)

@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None

     def get_tokenizer_messages(self):
         if isinstance(self.prompt, List):
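The same pair of optional fields is added to the legacy completions schema. An illustrative request body exercising the new knobs (the model name and values are placeholders, not taken from this diff):

    # Illustrative legacy completions payload; omit temperature/top_p to keep defaults.
    import json

    payload = {
        "model": "placeholder-model",
        "prompt": "Hello",
        "stream": False,
        "temperature": 0.6,
        "top_p": 0.95,
    }
    print(json.dumps(payload, indent=2))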
ktransformers/util/custom_gguf.py (view file @ 5aee6c04)

@@ -27,6 +27,7 @@ import torch
 import KTransformersOps
 from .custom_loader import SafeTensorLoader
 import ctypes
+import math

 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -230,7 +231,7 @@ class GGUFLoader:
         shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
         ggml_type = read_value(f, DATA_TYPES["uint32"])
         bad_offset = read_value(f, DATA_TYPES["uint64"])
-        n_elems = int(np.prod(shape))
+        n_elems = int(math.prod(shape))
         block_size, type_size = GGML_QUANT_SIZES[ggml_type]
         n_bytes = n_elems * type_size // block_size
         np_dims = tuple(reversed(shape))
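The element count now comes from math.prod instead of np.prod (hence the new import math). math.prod over a list of Python ints returns an exact, arbitrary-precision int, while np.prod computes in a fixed-width NumPy dtype; a quick sketch of the distinction, with an illustrative shape:

    # math.prod keeps exact Python ints; np.prod goes through a NumPy dtype.
    import math
    import numpy as np

    shape = [4096, 14336, 61]                 # illustrative GGUF tensor shape
    n_elems = math.prod(shape)                # exact Python int
    assert n_elems == int(np.prod(shape, dtype=np.int64))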
ktransformers/util/utils.py (view file @ 5aee6c04)

@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()

-def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True, mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     tokens = []

-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
         if cuda_graph_runner is None:
             use_cuda_graph = False
         if use_cuda_graph:
@@ -152,25 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token

-    torch.cuda.set_device(torch_device)
-    with torch.no_grad():
-        stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
-            past_key_values = StaticCache(
-                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-            )
-        else:
-            past_key_values = None
-        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
-        generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device)
-        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        if past_key_values != None:
-            past_key_values.cur_idx = cache_position
-        start_time = time.time()
-        inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
+    # TODO: use CUDA Graph for chunk prefill, may get small improvement
+    def chunk_prefill(inputs, cache_position, past_key_values):
+        if mode == "long_context":
+            inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
+        else:
@@ -182,9 +165,24 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
+        logits = model(inputs_embeds=inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
+        return logits
+
+    torch.cuda.set_device(torch_device)
+    with torch.no_grad():
+        stream = TextStreamer(tokenizer)
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
         generation_config, model_kwargs = model._prepare_generation_config(
-            None, max_length=max_new_tokens, do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
+            None, do_sample=True # change this to modify generate config
+            #top_k=5, top_p=0.85, temperature=0.1
         )
         try: # transformers==4.43
             logits_warper = (
@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             logits_warper = (
                 model._get_logits_warper(generation_config)
             )

+        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
+        generated_ids = torch.zeros(batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device)
+        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
+        start_time = time.time()
+
+        chunk_start = 0
+        while chunk_start < seq_length:
+            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            if past_key_values != None:
+                past_key_values.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
+            chunk_start += chunk_prefill_size
+
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)

         first_token_time = time.time() - start_time
         if use_flashinfer_mla:
@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         prefill_count = seq_length
         prefill_time = first_token_time
         if force_think:
-            print("<think>\n")
+            print("<think>")
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))
@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(int(next_token))
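In the utils.py hunks, decode_one_tokens now receives logits_warper and generation_config, so every decoded token goes through the same warped distribution as the first one: sample with torch.multinomial when do_sample is set, otherwise take the argmax. A standalone sketch of that selection step (logits_warper is any HF-style callable taking (input_ids, scores); the function name is an assumption):

    # Sketch of warped-logits token selection applied after each forward pass.
    import torch
    from torch import nn

    def pick_next_token(input_ids, logits, logits_warper, do_sample: bool):
        scores = logits_warper(input_ids, logits[:, -1, :])   # temperature/top_k/top_p
        if do_sample:
            probs = nn.functional.softmax(scores, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        return torch.argmax(scores, dim=-1)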
test_prompt.txt (deleted, 100644 → 0; view file @ 216a63b8)
This diff is collapsed.