ox696c / ktransformers

Commit c176e516, authored Feb 17, 2025 by Xie Weiyu

    server mix mla

Parent: 038bc308
Showing 2 changed files with 23 additions and 15 deletions:
  ktransformers/server/backend/interfaces/ktransformers.py   +13  -11
  ktransformers/server/backend/interfaces/transformers.py    +10   -4
ktransformers/server/backend/interfaces/ktransformers.py

@@ -15,7 +15,7 @@ from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device

+warm_uped = False

 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -73,11 +73,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()

-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        global warm_uped
+        if self.args.use_cuda_graph and ((warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2)):
+            warm_uped = True
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -91,14 +93,14 @@ class KTransformersInterface(TransformersInterface):
                     use_cache=True,
                 )

-            if hasattr(self, "cuda_graph_runner"):
-                logits = self.cuda_graph_runner(
-                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
-                )
-                self.cache.change_seq_length(1)
-                torch.cuda.synchronize()
-                logits = logits[0, -1, :]
-                return self.logits_to_token(logits)
+        if hasattr(self, "cuda_graph_runner"):
+            logits = self.cuda_graph_runner(
+                self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
+            )
+            self.cache.change_seq_length(1)
+            torch.cuda.synchronize()
+            logits = logits[0, -1, :]
+            return self.logits_to_token(logits)

         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
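The net effect in this file: CUDA-graph capture is now gated on the decode-step index i and a module-level warm_uped flag, while graph replay (the de-indented "if hasattr(self, 'cuda_graph_runner')" block) runs on every step rather than only inside the capture branch. The standalone sketch below isolates that control flow; the Decoder class and the stub runner are illustrative stand-ins (not the real KTransformersInterface or CUDAGraphRunner), only the step-index gating mirrors the hunk above.

warm_uped = False  # module-level flag, as in the diff


class CUDAGraphRunner:
    """Stub standing in for ktransformers.util.cuda_graph_runner.CUDAGraphRunner."""

    def capture(self, step):
        # Stand-in for recording the model's decode forward pass into a CUDA graph.
        print(f"capturing CUDA graph at decode step {step}")

    def __call__(self, step):
        return f"graph replay at step {step}"


class Decoder:
    def __init__(self, use_cuda_graph=True):
        self.use_cuda_graph = use_cuda_graph

    def decode_one_tokens(self, i):
        global warm_uped
        # First request ever: step 1 runs eagerly (presumably to finish lazy
        # warm-up), capture happens at step 2. Once warm_uped is set, capture
        # can happen straight away at step 1.
        if self.use_cuda_graph and (
            (warm_uped and int(i) == 1) or (not warm_uped and int(i) == 2)
        ):
            warm_uped = True
            if not hasattr(self, "cuda_graph_runner"):
                self.cuda_graph_runner = CUDAGraphRunner()
                self.cuda_graph_runner.capture(i)
        # Replay the captured graph on every later step; in the diff this
        # block now sits outside the capture branch.
        if hasattr(self, "cuda_graph_runner"):
            return self.cuda_graph_runner(i)
        return f"eager decode at step {i}"


if __name__ == "__main__":
    d = Decoder()
    for i in range(1, 5):
        print(d.decode_one_tokens(i))

Running the sketch prints an eager decode at step 1, a capture followed by replay at step 2, and replays from step 3 on, which is the sequencing the new gating produces for a cold server.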
ktransformers/server/backend/interfaces/transformers.py

@@ -18,7 +18,7 @@ import sys, os
 from ..base import ThreadContext, BackendInterfaceBase
 from ktransformers.server.config.log import logger
 from ..args import ConfigArgs, default_args
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton

 # This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
 class TextStreamer:
@@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase):
         self.ever_generated_ids.add(last)
         return last

-    def decode_one_tokens(self):
+    def decode_one_tokens(self, i):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(self.args.device)
             logits = self.model(
@@ -291,9 +291,15 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def generate(self):
         self.profiler.set_counter("decode", 0)
-        for _ in range(1, self.args.max_new_tokens):
+        for i in range(1, self.args.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
-                next_token = self.decode_one_tokens()
+                if i > 1 and flashinfer_enabled:
+                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
+                                                 num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
+                                                 head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
+                                                 sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5),
+                                                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+                next_token = self.decode_one_tokens(i)
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id:
                 assert self.args.batch_size == 1
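With flashinfer enabled, the MLA decode wrapper is re-planned before every decode step after the first (the step-1 plan presumably carries over from prefill), since the KV cache grows by one token per step; that is why the new sequence length passed is self.active_cache_position.to(torch.int32) + 1. The remaining arguments come straight from the model config: kv_lora_rank maps to head_dim_ckv (the compressed/latent KV dimension), qk_rope_head_dim maps to head_dim_kpe (the decoupled RoPE part), and the softmax scale is 1/sqrt(qk_nope_head_dim + qk_rope_head_dim). The sketch below only shows how those kwargs are derived; the MLAConfig dataclass and the DeepSeek-V3-style numbers in the example are illustrative assumptions, not part of this commit.

from dataclasses import dataclass
import math


@dataclass
class MLAConfig:
    num_attention_heads: int
    kv_lora_rank: int       # passed as head_dim_ckv: dim of the compressed (latent) KV
    qk_rope_head_dim: int   # passed as head_dim_kpe: dim of the decoupled RoPE part
    qk_nope_head_dim: int   # non-RoPE part of the per-head query/key dim


def plan_kwargs(cfg: MLAConfig, page_size: int) -> dict:
    # Softmax scale over the full per-head query dim (nope + rope), i.e. 1/sqrt(d),
    # exactly the expression passed to plan_all in the hunk above.
    sm_scale = (cfg.qk_rope_head_dim + cfg.qk_nope_head_dim) ** (-0.5)
    return dict(
        num_heads=cfg.num_attention_heads,
        head_dim_ckv=cfg.kv_lora_rank,
        head_dim_kpe=cfg.qk_rope_head_dim,
        page_size=page_size,
        sm_scale=sm_scale,
    )


if __name__ == "__main__":
    # Illustrative DeepSeek-V3-style MLA dimensions (assumed, not taken from this commit).
    cfg = MLAConfig(num_attention_heads=128, kv_lora_rank=512,
                    qk_rope_head_dim=64, qk_nope_head_dim=128)
    kwargs = plan_kwargs(cfg, page_size=64)
    print(kwargs)  # sm_scale = 1/sqrt(192) ≈ 0.0722
    assert math.isclose(kwargs["sm_scale"], 1 / math.sqrt(192))

For these example dimensions the scale works out to 1/sqrt(128 + 64) ≈ 0.0722, i.e. the standard 1/sqrt(head_dim) attention scaling applied to the full query head dimension.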