Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ox696c
ktransformers
Commits
5aee6c04
Commit
5aee6c04
authored
Mar 03, 2025
by
Azure
Browse files
Merge branch 'main' into develop-0.2.3
parents
216a63b8
48b98007
Changes
28
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
123 additions
and
101 deletions
+123
-101
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+53
-38
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+21
-15
ktransformers/server/config/config.py
ktransformers/server/config/config.py
+2
-1
ktransformers/server/schemas/endpoints/chat.py
ktransformers/server/schemas/endpoints/chat.py
+3
-1
ktransformers/server/schemas/legacy/completions.py
ktransformers/server/schemas/legacy/completions.py
+2
-0
ktransformers/util/custom_gguf.py
ktransformers/util/custom_gguf.py
+2
-1
ktransformers/util/utils.py
ktransformers/util/utils.py
+40
-25
test_prompt.txt
test_prompt.txt
+0
-20
No files found.
ktransformers/server/backend/interfaces/ktransformers.py
View file @
5aee6c04
...
@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
...
@@ -14,9 +14,9 @@ from ktransformers.models.custom_cache import StaticCache
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.util.cuda_graph_runner
import
CUDAGraphRunner
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.local_chat
import
custom_models
,
default_optimize_rules
from
ktransformers.util.utils
import
get_device
from
ktransformers.util.utils
import
get_device
from
typing
import
Optional
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
from
ktransformers.operators.flashinfer_wrapper
import
flashinfer_enabled
,
MLAWrapperSingleton
warm_uped
=
False
warm_uped
=
False
class
KTransformersThreadContext
(
TransformersThreadContext
):
class
KTransformersThreadContext
(
TransformersThreadContext
):
...
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
...
@@ -29,6 +29,16 @@ class KTransformersInterface(TransformersInterface):
torch
.
set_grad_enabled
(
False
)
torch
.
set_grad_enabled
(
False
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
,
trust_remote_code
=
args
.
trust_remote_code
)
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_dir
,
device
=
args
.
device
,
trust_remote_code
=
args
.
trust_remote_code
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
args
.
trust_remote_code
)
config
=
AutoConfig
.
from_pretrained
(
args
.
model_dir
,
trust_remote_code
=
args
.
trust_remote_code
)
try
:
generation_config
=
GenerationConfig
.
from_pretrained
(
args
.
model_dir
)
except
:
generation_config
=
GenerationConfig
(
max_length
=
args
.
max_new_tokens
,
temperature
=
args
.
temperature
,
top_p
=
args
.
top_p
,
do_sample
=
True
)
torch
.
set_default_dtype
(
config
.
torch_dtype
)
torch
.
set_default_dtype
(
config
.
torch_dtype
)
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
if
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
config
.
_attn_implementation
=
"flash_attention_2"
config
.
_attn_implementation
=
"flash_attention_2"
...
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -49,7 +59,7 @@ class KTransformersInterface(TransformersInterface):
" belong to current model):"
" belong to current model):"
)
)
optimize_and_load_gguf
(
self
.
model
,
optimize_config_path
,
gguf_path
,
config
)
optimize_and_load_gguf
(
self
.
model
,
optimize_config_path
,
gguf_path
,
config
)
self
.
model
.
generation_config
=
generation_config
self
.
device_map
=
self
.
model
.
gguf_loader
.
tensor_device_map
self
.
device_map
=
self
.
model
.
gguf_loader
.
tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
self
.
cache
=
StaticCache
(
self
.
cache
=
StaticCache
(
...
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -60,16 +70,7 @@ class KTransformersInterface(TransformersInterface):
dtype
=
self
.
model
.
dtype
,
dtype
=
self
.
model
.
dtype
,
)
)
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
try
:
self
.
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
args
.
model_dir
)
except
:
gen_config
=
GenerationConfig
(
max_length
=
128
,
temperature
=
0.7
,
top_p
=
0.9
,
do_sample
=
True
)
self
.
model
.
generation_config
=
gen_config
if
self
.
model
.
generation_config
.
pad_token_id
is
None
:
if
self
.
model
.
generation_config
.
pad_token_id
is
None
:
self
.
model
.
generation_config
.
pad_token_id
=
self
.
model
.
generation_config
.
eos_token_id
self
.
model
.
generation_config
.
pad_token_id
=
self
.
model
.
generation_config
.
eos_token_id
self
.
streamer
=
TextStreamer
(
self
.
tokenizer
)
self
.
streamer
=
TextStreamer
(
self
.
tokenizer
)
...
@@ -110,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
...
@@ -110,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
warm_uped
=
True
warm_uped
=
True
if
self
.
use_static_cache
:
if
self
.
use_static_cache
:
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
torch_device
)
logits
=
self
.
model
(
logits
=
self
.
model
(
self
.
current_ids
.
to
(
torch_device
),
self
.
current_ids
.
to
(
torch_device
),
cache_position
=
self
.
active_cache_position
,
cache_position
=
self
.
active_cache_position
,
past_key_values
=
self
.
cache
,
past_key_values
=
self
.
cache
,
attention_mask
=
mask
,
return_dict
=
False
,
return_dict
=
False
,
use_cache
=
True
,
use_cache
=
True
,
)[
0
]
)[
0
]
...
@@ -128,10 +127,13 @@ class KTransformersInterface(TransformersInterface):
...
@@ -128,10 +127,13 @@ class KTransformersInterface(TransformersInterface):
@
torch
.
no_grad
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
):
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
,
temperature
:
Optional
[
float
],
top_p
:
Optional
[
float
]
):
input_ids_length
=
input_ids
.
shape
[
-
1
]
input_ids_length
=
input_ids
.
shape
[
-
1
]
if
(
input_ids_length
>=
self
.
args
.
cache_lens
):
logger
.
warning
(
f
"input_ids_length
{
input_ids_length
}
> cache_lens
{
self
.
args
.
cache_lens
}
"
)
self
.
seq_length
=
input_ids_length
return
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
"cuda:0"
if
device
==
"cuda"
else
device
device
=
"cuda:0"
if
device
==
"cuda"
else
device
...
@@ -166,44 +168,57 @@ class KTransformersInterface(TransformersInterface):
...
@@ -166,44 +168,57 @@ class KTransformersInterface(TransformersInterface):
self
.
ever_generated_ids
.
clear
()
self
.
ever_generated_ids
.
clear
()
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
self
.
profiler
.
set_counter
(
"prefill"
,
input_ids_length
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
logger
.
debug
(
f
"generate_ids:
{
self
.
generated_ids
.
shape
}
"
)
former_seq_length
=
self
.
seq_length
former_seq_length
=
self
.
seq_length
self
.
seq_length
+=
input_ids_length
self
.
seq_length
+=
input_ids_length
expected_length
=
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
expected_length
=
min
(
self
.
seq_length
+
self
.
args
.
max_new_tokens
+
1
,
self
.
args
.
cache_lens
)
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
delta_length
=
expected_length
-
self
.
generated_ids
.
shape
[
-
1
]
if
delta_length
>
0
:
if
delta_length
>
0
:
new_generate_ids
=
torch
.
zeros
(
new_generate_ids
=
torch
.
zeros
(
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
self
.
args
.
batch_size
,
delta_length
,
dtype
=
torch
.
int
,
device
=
self
.
args
.
device
)
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
self
.
generated_ids
=
torch
.
cat
([
self
.
generated_ids
,
new_generate_ids
],
dim
=-
1
)
else
:
logger
.
warning
(
f
"seq_length bigger than cache_lens, killed"
)
exit
(
0
)
logger
.
debug
(
f
"cache position:
{
former_seq_length
}
to
{
self
.
seq_length
}
"
)
logger
.
debug
(
f
"cache position:
{
former_seq_length
}
to
{
self
.
seq_length
}
"
)
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
device
)
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
device
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
device
)
if
not
(
type
(
self
)
is
TransformersInterface
):
if
not
(
type
(
self
)
is
TransformersInterface
):
input_ids
=
input_ids
.
to
(
"cpu"
)
input_ids
=
input_ids
.
to
(
"cpu"
)
inputs_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
).
to
(
device
)
torch
.
cuda
.
set_device
(
device
)
def
chunk_prefill
(
input_ids
,
cache_position
):
if
flashinfer_enabled
:
inputs_embeds
=
self
.
model
.
model
.
embed_tokens
(
input_ids
).
to
(
device
)
MLAWrapperSingleton
.
need_plan_all
()
torch
.
cuda
.
set_device
(
device
)
if
self
.
use_static_cache
:
if
flashinfer_enabled
:
logits
=
self
.
model
(
MLAWrapperSingleton
.
need_plan_all
()
inputs_embeds
=
inputs_embeds
,
if
self
.
use_static_cache
:
cache_position
=
cache_position
,
logits
=
self
.
model
(
past_key_values
=
self
.
cache
,
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
,
cache_position
=
cache_position
,
use_cache
=
True
,
past_key_values
=
self
.
cache
,
attention_mask
=
mask
,
return_dict
=
False
,
)[
0
]
use_cache
=
True
,
else
:
)[
0
]
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
return
logits
chunk_start
=
0
while
chunk_start
<
input_ids_length
:
chunk_end
=
min
(
chunk_start
+
self
.
args
.
chunk_prefill_size
,
input_ids_length
)
if
self
.
cache
!=
None
:
self
.
cache
.
cur_idx
=
cache_position
[
chunk_start
:
chunk_end
]
logits
=
chunk_prefill
(
input_ids
[:,
chunk_start
:
chunk_end
],
cache_position
[
chunk_start
:
chunk_end
])
chunk_start
+=
self
.
args
.
chunk_prefill_size
if
flashinfer_enabled
:
if
flashinfer_enabled
:
MLAWrapperSingleton
.
reset_buffer
()
MLAWrapperSingleton
.
reset_buffer
()
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
self
.
prepare_logits_wrapper
(
input_ids
,
device
,
temperature
,
top_p
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
yield
self
.
append_new_tokens
(
next_token
)
...
@@ -212,7 +227,7 @@ class KTransformersInterface(TransformersInterface):
...
@@ -212,7 +227,7 @@ class KTransformersInterface(TransformersInterface):
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
return
torch
.
tensor
([
self
.
seq_length
-
1
],
device
=
device
)
return
torch
.
tensor
([
self
.
seq_length
-
1
],
device
=
device
)
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
async
with
self
.
_infer_lock
:
async
with
self
.
_infer_lock
:
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
):
async
for
v
in
super
().
inference
(
local_messages
,
thread_id
,
temperature
,
top_p
):
yield
v
yield
v
ktransformers/server/backend/interfaces/transformers.py
View file @
5aee6c04
...
@@ -13,6 +13,7 @@ from transformers import (
...
@@ -13,6 +13,7 @@ from transformers import (
from
ktransformers.server.config.config
import
Config
from
ktransformers.server.config.config
import
Config
from
ktransformers.server.schemas.base
import
ObjectID
from
ktransformers.server.schemas.base
import
ObjectID
from
ktransformers.server.utils.multi_timer
import
Profiler
from
ktransformers.server.utils.multi_timer
import
Profiler
from
torch.nn.attention
import
SDPBackend
import
torch
import
torch
import
sys
,
os
import
sys
,
os
from
..base
import
ThreadContext
,
BackendInterfaceBase
from
..base
import
ThreadContext
,
BackendInterfaceBase
...
@@ -202,20 +203,23 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -202,20 +203,23 @@ class TransformersInterface(BackendInterfaceBase):
self
.
seq_length
+=
1
self
.
seq_length
+=
1
return
self
.
streamer
.
put
(
new_tokens
)
return
self
.
streamer
.
put
(
new_tokens
)
def
prepare_logits_wrapper
(
self
,
inputs
,
device
):
def
prepare_logits_wrapper
(
self
,
inputs
,
device
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
if
temperature
is
None
or
temperature
==
0
:
temperature
=
self
.
model
.
generation_config
.
temperature
if
top_p
is
None
:
top_p
=
self
.
model
.
generation_config
.
top_p
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
generation_config
,
model_kwargs
=
self
.
model
.
_prepare_generation_config
(
None
,
max_length
=
self
.
args
.
max_new_tokens
,
None
,
max_length
=
self
.
args
.
max_new_tokens
,
do_sample
=
True
,
do_sample
=
True
,
top_k
=
self
.
args
.
top_k
,
top_k
=
self
.
args
.
top_k
,
top_p
=
self
.
args
.
top_p
,
top_p
=
top_p
,
temperature
=
self
.
args
.
temperature
,
temperature
=
temperature
,
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
repetition_penalty
=
self
.
args
.
repetition_penalty
# change this to modify generate config
)
)
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
generation_config
=
generation_config
try
:
# transformers==4.43
try
:
# transformers==4.43
self
.
logits_warper
=
(
self
.
logits_warper
=
(
self
.
model
.
_get_logits_warper
(
generation_config
,
device
=
device
)
self
.
model
.
_get_logits_warper
(
generation_config
,
device
=
device
)
)
)
except
:
except
:
self
.
logits_warper
=
(
self
.
logits_warper
=
(
...
@@ -239,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -239,12 +243,10 @@ class TransformersInterface(BackendInterfaceBase):
def
decode_one_tokens
(
self
):
def
decode_one_tokens
(
self
):
if
self
.
use_static_cache
:
if
self
.
use_static_cache
:
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
self
.
args
.
device
)
logits
=
self
.
model
(
logits
=
self
.
model
(
self
.
current_ids
,
self
.
current_ids
,
cache_position
=
self
.
active_cache_position
,
cache_position
=
self
.
active_cache_position
,
past_key_values
=
self
.
cache
,
past_key_values
=
self
.
cache
,
attention_mask
=
mask
,
return_dict
=
False
,
return_dict
=
False
,
use_cache
=
True
,
use_cache
=
True
,
)[
0
]
)[
0
]
...
@@ -255,7 +257,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -255,7 +257,7 @@ class TransformersInterface(BackendInterfaceBase):
return
self
.
logits_to_token
(
logits
)
return
self
.
logits_to_token
(
logits
)
@
torch
.
no_grad
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
):
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
input_ids_length
=
input_ids
.
shape
[
-
1
]
input_ids_length
=
input_ids
.
shape
[
-
1
]
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
...
@@ -306,7 +308,6 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -306,7 +308,6 @@ class TransformersInterface(BackendInterfaceBase):
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
self
.
args
.
device
)
cache_position
=
torch
.
arange
(
former_seq_length
,
self
.
seq_length
,
device
=
self
.
args
.
device
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
self
.
generated_ids
[:,
cache_position
]
=
input_ids
.
to
(
self
.
args
.
device
).
to
(
torch
.
int
)
mask
=
torch
.
ones
((
1
,
self
.
seq_length
)).
to
(
self
.
args
.
device
)
device
=
input_ids
.
device
device
=
input_ids
.
device
if
not
(
type
(
self
)
is
TransformersInterface
):
if
not
(
type
(
self
)
is
TransformersInterface
):
input_ids
=
input_ids
.
to
(
"cpu"
)
input_ids
=
input_ids
.
to
(
"cpu"
)
...
@@ -318,21 +319,26 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -318,21 +319,26 @@ class TransformersInterface(BackendInterfaceBase):
past_key_values
=
self
.
cache
,
past_key_values
=
self
.
cache
,
return_dict
=
False
,
return_dict
=
False
,
use_cache
=
True
,
use_cache
=
True
,
attention_mask
=
mask
,
)[
0
]
)[
0
]
else
:
else
:
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
logits
=
self
.
model
(
inputs_embeds
=
inputs_embeds
,
return_dict
=
False
)[
0
]
self
.
prepare_logits_wrapper
(
input_ids
,
device
)
self
.
prepare_logits_wrapper
(
input_ids
,
device
,
temperature
,
top_p
)
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
next_token
=
self
.
logits_to_token
(
logits
[
0
,
-
1
,
:])
yield
self
.
append_new_tokens
(
next_token
)
yield
self
.
append_new_tokens
(
next_token
)
@
torch
.
no_grad
@
torch
.
no_grad
def
generate
(
self
):
def
generate
(
self
):
self
.
args
.
max_new_tokens
=
min
(
self
.
args
.
max_new_tokens
,
self
.
args
.
cache_lens
-
self
.
seq_length
)
if
(
self
.
args
.
max_new_tokens
<=
0
):
logger
.
warning
(
"max_new_tokens is less than 0"
)
yield
self
.
streamer
.
end
()
return
logger
.
info
(
f
"max_new_tokens:
{
self
.
args
.
max_new_tokens
}
"
)
self
.
profiler
.
set_counter
(
"decode"
,
0
)
self
.
profiler
.
set_counter
(
"decode"
,
0
)
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
with
torch
.
nn
.
attention
.
sdpa_kernel
(
backends
=
[
SDPBackend
.
FLASH_ATTENTION
,
SDPBackend
.
MATH
,
SDPBackend
.
EFFICIENT_ATTENTION
]):
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_flash
=
False
,
enable_mem_efficient
=
False
,
enable_math
=
True
):
if
flashinfer_enabled
:
if
flashinfer_enabled
:
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
self
.
active_cache_position
.
to
(
torch
.
int32
)
+
1
,
MLAWrapperSingleton
.
plan_all
(
None
,
None
,
None
,
self
.
active_cache_position
.
to
(
torch
.
int32
)
+
1
,
num_heads
=
self
.
model
.
config
.
num_attention_heads
,
head_dim_ckv
=
self
.
model
.
config
.
kv_lora_rank
,
num_heads
=
self
.
model
.
config
.
num_attention_heads
,
head_dim_ckv
=
self
.
model
.
config
.
kv_lora_rank
,
...
@@ -359,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -359,7 +365,7 @@ class TransformersInterface(BackendInterfaceBase):
self
.
last_request_id
=
thread_id
self
.
last_request_id
=
thread_id
return
True
return
True
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
):
async
def
inference
(
self
,
local_messages
,
thread_id
:
str
,
temperature
:
Optional
[
float
]
=
None
,
top_p
:
Optional
[
float
]
=
None
):
self
.
streamer
.
reset
()
self
.
streamer
.
reset
()
self
.
profiler
.
create_and_start_timer
(
"tokenize"
)
self
.
profiler
.
create_and_start_timer
(
"tokenize"
)
if
isinstance
(
local_messages
,
List
):
if
isinstance
(
local_messages
,
List
):
...
@@ -386,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -386,7 +392,7 @@ class TransformersInterface(BackendInterfaceBase):
print
(
think
,
end
=
""
,
flush
=
True
)
print
(
think
,
end
=
""
,
flush
=
True
)
yield
think
yield
think
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)):
for
t
in
self
.
prefill
(
input_ids
,
self
.
check_is_new
(
thread_id
)
,
temperature
,
top_p
):
# output think token after prefill done
# output think token after prefill done
if
t
is
not
None
:
if
t
is
not
None
:
print
(
t
,
end
=
""
,
flush
=
True
)
print
(
t
,
end
=
""
,
flush
=
True
)
...
...
ktransformers/server/config/config.py
View file @
5aee6c04
...
@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
...
@@ -105,7 +105,8 @@ class Config(metaclass=Singleton):
self
.
total_context
=
self
.
model
.
get
(
"total_context"
,
2
**
18
)
self
.
total_context
=
self
.
model
.
get
(
"total_context"
,
2
**
18
)
self
.
max_batch_size
=
self
.
model
.
get
(
"max_batch_size"
,
20
if
self
.
paged
else
1
)
self
.
max_batch_size
=
self
.
model
.
get
(
"max_batch_size"
,
20
if
self
.
paged
else
1
)
self
.
max_chunk_size
=
self
.
model
.
get
(
"max_chunk_size"
,
2048
)
self
.
chunk_prefill_size
=
self
.
model
.
get
(
"chunk_prefill_size"
,
8192
)
self
.
max_new_tokens
=
self
.
model
.
get
(
"max_new_tokens"
,
2000
)
self
.
max_new_tokens
=
self
.
model
.
get
(
"max_new_tokens"
,
2000
)
self
.
json_mode
=
self
.
model
.
get
(
"json_mode"
,
False
)
self
.
json_mode
=
self
.
model
.
get
(
"json_mode"
,
False
)
self
.
healing
=
self
.
model
.
get
(
"healing"
,
False
)
self
.
healing
=
self
.
model
.
get
(
"healing"
,
False
)
...
...
ktransformers/server/schemas/endpoints/chat.py
View file @
5aee6c04
...
@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
...
@@ -25,7 +25,9 @@ class ChatCompletionCreate(BaseModel):
messages
:
List
[
Message
]
messages
:
List
[
Message
]
model
:
str
model
:
str
stream
:
bool
=
False
stream
:
bool
=
False
temperature
:
Optional
[
float
]
=
None
top_p
:
Optional
[
float
]
=
None
def
get_tokenizer_messages
(
self
):
def
get_tokenizer_messages
(
self
):
return
[
m
.
to_tokenizer_message
()
for
m
in
self
.
messages
]
return
[
m
.
to_tokenizer_message
()
for
m
in
self
.
messages
]
...
...
ktransformers/server/schemas/legacy/completions.py
View file @
5aee6c04
...
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
...
@@ -9,6 +9,8 @@ class CompletionCreate(BaseModel):
model
:
str
model
:
str
prompt
:
str
|
List
[
str
]
prompt
:
str
|
List
[
str
]
stream
:
bool
=
False
stream
:
bool
=
False
temperature
:
Optional
[
float
]
=
None
top_p
:
Optional
[
float
]
=
None
def
get_tokenizer_messages
(
self
):
def
get_tokenizer_messages
(
self
):
if
isinstance
(
self
.
prompt
,
List
):
if
isinstance
(
self
.
prompt
,
List
):
...
...
ktransformers/util/custom_gguf.py
View file @
5aee6c04
...
@@ -27,6 +27,7 @@ import torch
...
@@ -27,6 +27,7 @@ import torch
import
KTransformersOps
import
KTransformersOps
from
.custom_loader
import
SafeTensorLoader
from
.custom_loader
import
SafeTensorLoader
import
ctypes
import
ctypes
import
math
class
GGMLQuantizationType
(
IntEnum
):
class
GGMLQuantizationType
(
IntEnum
):
F32
=
0
F32
=
0
...
@@ -230,7 +231,7 @@ class GGUFLoader:
...
@@ -230,7 +231,7 @@ class GGUFLoader:
shape
=
[
read_value
(
f
,
DATA_TYPES
[
"uint64"
])
for
_
in
range
(
shape_len
)]
shape
=
[
read_value
(
f
,
DATA_TYPES
[
"uint64"
])
for
_
in
range
(
shape_len
)]
ggml_type
=
read_value
(
f
,
DATA_TYPES
[
"uint32"
])
ggml_type
=
read_value
(
f
,
DATA_TYPES
[
"uint32"
])
bad_offset
=
read_value
(
f
,
DATA_TYPES
[
"uint64"
])
bad_offset
=
read_value
(
f
,
DATA_TYPES
[
"uint64"
])
n_elems
=
int
(
np
.
prod
(
shape
))
n_elems
=
int
(
math
.
prod
(
shape
))
block_size
,
type_size
=
GGML_QUANT_SIZES
[
ggml_type
]
block_size
,
type_size
=
GGML_QUANT_SIZES
[
ggml_type
]
n_bytes
=
n_elems
*
type_size
//
block_size
n_bytes
=
n_elems
*
type_size
//
block_size
np_dims
=
tuple
(
reversed
(
shape
))
np_dims
=
tuple
(
reversed
(
shape
))
...
...
ktransformers/util/utils.py
View file @
5aee6c04
...
@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
...
@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
module
.
load
()
module
.
load
()
def
prefill_and_generate
(
model
,
tokenizer
,
inputs
,
max_new_tokens
=
10000
,
use_cuda_graph
:
bool
=
True
,
def
prefill_and_generate
(
model
,
tokenizer
,
inputs
,
max_new_tokens
=
10000
,
use_cuda_graph
:
bool
=
True
,
mode
=
'normal'
,
force_think
:
bool
=
False
,
use_flashinfer_mla
=
False
,
mode
=
'normal'
,
force_think
:
bool
=
False
,
chunk_prefill_size
=
16384
,
use_flashinfer_mla
=
False
,
num_heads
=
None
,
head_dim_ckv
=
None
,
head_dim_kpe
=
None
,
q_head_dim
=
None
):
num_heads
=
None
,
head_dim_ckv
=
None
,
head_dim_kpe
=
None
,
q_head_dim
=
None
):
import
os
import
os
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
...
@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
tokens
=
[]
tokens
=
[]
def
decode_one_tokens
(
cuda_graph_runner
,
cur_token
,
position_ids
,
cache_position
,
past_key_values
,
use_cuda_graph
:
bool
=
True
):
def
decode_one_tokens
(
cuda_graph_runner
,
cur_token
,
position_ids
,
cache_position
,
past_key_values
,
logits_warper
,
generation_config
,
use_cuda_graph
:
bool
=
True
):
if
cuda_graph_runner
is
None
:
if
cuda_graph_runner
is
None
:
use_cuda_graph
=
False
use_cuda_graph
=
False
if
use_cuda_graph
:
if
use_cuda_graph
:
...
@@ -152,25 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -152,25 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
return
next_token
return
next_token
torch
.
cuda
.
set_device
(
torch_device
)
# TODO: use CUDA Graph for chunk prefill, may get small improvement
with
torch
.
no_grad
():
def
chunk_prefill
(
inputs
,
cache_position
,
past_key_values
):
stream
=
TextStreamer
(
tokenizer
)
if
mode
!=
'long_context'
:
past_key_values
=
StaticCache
(
config
=
model
.
config
,
max_batch_size
=
1
,
max_cache_len
=
seq_length
+
max_new_tokens
,
device
=
device_map
,
dtype
=
model
.
dtype
)
else
:
past_key_values
=
None
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
,
dtype
=
torch
.
int32
)
generated_ids
=
torch
.
zeros
(
batch_size
,
seq_length
+
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
torch_device
)
generated_ids
[:,
cache_position
]
=
inputs
.
to
(
torch_device
).
to
(
torch
.
int
)
if
past_key_values
!=
None
:
past_key_values
.
cur_idx
=
cache_position
start_time
=
time
.
time
()
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
)).
to
(
torch_device
)
if
mode
==
"long_context"
:
if
mode
==
"long_context"
:
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
))
inputs_embeds
=
model
.
model
.
embed_tokens
(
inputs
.
to
(
"cpu"
))
else
:
else
:
...
@@ -182,9 +165,24 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -182,9 +165,24 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits
=
model
(
logits
=
model
(
inputs_embeds
=
inputs_embeds
,
cache_position
=
cache_position
,
past_key_values
=
past_key_values
,
return_dict
=
False
,
use_cache
=
True
inputs_embeds
=
inputs_embeds
,
cache_position
=
cache_position
,
past_key_values
=
past_key_values
,
return_dict
=
False
,
use_cache
=
True
)[
0
][:,
-
1
,:].
unsqueeze
(
0
).
clone
().
to
(
torch_device
)
)[
0
][:,
-
1
,:].
unsqueeze
(
0
).
clone
().
to
(
torch_device
)
return
logits
torch
.
cuda
.
set_device
(
torch_device
)
with
torch
.
no_grad
():
stream
=
TextStreamer
(
tokenizer
)
if
mode
!=
'long_context'
:
past_key_values
=
StaticCache
(
config
=
model
.
config
,
max_batch_size
=
1
,
max_cache_len
=
seq_length
+
max_new_tokens
,
device
=
device_map
,
dtype
=
model
.
dtype
)
else
:
past_key_values
=
None
generation_config
,
model_kwargs
=
model
.
_prepare_generation_config
(
generation_config
,
model_kwargs
=
model
.
_prepare_generation_config
(
None
,
max_length
=
max_new_tokens
,
None
,
do_sample
=
True
do_sample
=
True
,
top_k
=
5
,
top_p
=
0.85
,
temperature
=
0.1
# change this to modify generate config
# change this to modify generate config
#top_k=5, top_p=0.85, temperature=0.1
)
)
try
:
# transformers==4.43
try
:
# transformers==4.43
logits_warper
=
(
logits_warper
=
(
...
@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
logits_warper
=
(
logits_warper
=
(
model
.
_get_logits_warper
(
generation_config
)
model
.
_get_logits_warper
(
generation_config
)
)
)
cache_position
=
torch
.
arange
(
seq_length
,
device
=
torch_device
,
dtype
=
torch
.
int32
)
generated_ids
=
torch
.
zeros
(
batch_size
,
seq_length
+
max_new_tokens
+
1
,
dtype
=
torch
.
int
,
device
=
torch_device
)
generated_ids
[:,
cache_position
]
=
inputs
.
to
(
torch_device
).
to
(
torch
.
int
)
start_time
=
time
.
time
()
chunk_start
=
0
while
chunk_start
<
seq_length
:
chunk_end
=
min
(
chunk_start
+
chunk_prefill_size
,
seq_length
)
if
past_key_values
!=
None
:
past_key_values
.
cur_idx
=
cache_position
[
chunk_start
:
chunk_end
]
logits
=
chunk_prefill
(
inputs
[:,
chunk_start
:
chunk_end
],
cache_position
[
chunk_start
:
chunk_end
],
past_key_values
)
chunk_start
+=
chunk_prefill_size
next_token_scores
=
logits_warper
(
inputs
,
logits
[:,
-
1
,
:])
next_token_scores
=
logits_warper
(
inputs
,
logits
[:,
-
1
,
:])
if
generation_config
.
do_sample
:
if
generation_config
.
do_sample
:
probs
=
nn
.
functional
.
softmax
(
next_token_scores
,
dim
=-
1
)
probs
=
nn
.
functional
.
softmax
(
next_token_scores
,
dim
=-
1
)
next_token
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
1
)
next_token
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
1
)
else
:
else
:
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
next_token
=
torch
.
argmax
(
next_token_scores
,
dim
=-
1
)
first_token_time
=
time
.
time
()
-
start_time
first_token_time
=
time
.
time
()
-
start_time
if
use_flashinfer_mla
:
if
use_flashinfer_mla
:
...
@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
prefill_count
=
seq_length
prefill_count
=
seq_length
prefill_time
=
first_token_time
prefill_time
=
first_token_time
if
force_think
:
if
force_think
:
print
(
"<think>
\n
"
)
print
(
"<think>"
)
print
(
stream
.
put
(
next_token
.
item
()),
end
=
""
,
flush
=
True
)
print
(
stream
.
put
(
next_token
.
item
()),
end
=
""
,
flush
=
True
)
generated_ids
[:,
seq_length
]
=
next_token
generated_ids
[:,
seq_length
]
=
next_token
tokens
.
append
(
int
(
next_token
))
tokens
.
append
(
int
(
next_token
))
...
@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
...
@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
warm_uped
=
True
warm_uped
=
True
cuda_graph_runner
=
CUDAGraphRunner
()
cuda_graph_runner
=
CUDAGraphRunner
()
cuda_graph_runner
.
capture
(
model
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
torch_device
,
return_dict
=
False
,
use_cache
=
True
)
cuda_graph_runner
.
capture
(
model
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
torch_device
,
return_dict
=
False
,
use_cache
=
True
)
next_token
=
decode_one_tokens
(
cuda_graph_runner
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
use_cuda_graph
).
to
(
torch_device
)
next_token
=
decode_one_tokens
(
cuda_graph_runner
,
next_token
.
unsqueeze
(
0
),
position_ids
,
cache_position
,
past_key_values
,
logits_warper
,
generation_config
,
use_cuda_graph
).
to
(
torch_device
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
inputs
=
torch
.
cat
((
inputs
,
next_token
.
unsqueeze
(
0
)),
dim
=-
1
)
generated_ids
[:,
cache_position
]
=
next_token
.
int
()
generated_ids
[:,
cache_position
]
=
next_token
.
int
()
tokens
.
append
(
int
(
next_token
))
tokens
.
append
(
int
(
next_token
))
...
...
test_prompt.txt
deleted
100644 → 0
View file @
216a63b8
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment