ox696c / ktransformers
Commit ee24eb8d, authored Feb 17, 2025 by ceerrep

fix: fix server for triton kernel

Parent: bb1cadff
Showing 2 changed files, with 8 additions and 4 deletions:

  ktransformers/server/backend/interfaces/ktransformers.py (+8, -1)
  ktransformers/server/main.py (+0, -3)
ktransformers/server/backend/interfaces/ktransformers.py

@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 
+warm_uped = False
+
 class KTransformersThreadContext(TransformersThreadContext):
     pass
 
@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        if self.args.use_cuda_graph:
+            torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(

@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
         logits = logits[0, -1, :]
+        warm_uped = True
         return self.logits_to_token(logits)

@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        torch.cuda.set_device(device)
         if self.use_static_cache:
             logits = self.model(inputs_embeds=inputs_embeds,
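Why the warm_uped flag matters: capturing a CUDA graph on the very first decode would record Triton's one-time just-in-time kernel compilation into the graph, which is presumably the failure behind this fix. The commit therefore runs at least one eager decode (setting warm_uped = True at its end) before any capture is attempted. A minimal, self-contained sketch of this warm-up-then-capture pattern in plain PyTorch follows; the function and variable names are illustrative and not part of this commit.

import torch

def capture_after_warmup(model, example_input):
    # Warm-up pass: the first eager call lets Triton JIT-compile its kernels.
    # Compilation must happen outside graph capture, since a CUDA graph only
    # records kernel launches and cannot replay one-time compilation work.
    static_input = example_input.clone()
    with torch.no_grad():
        static_output = model(static_input)
    torch.cuda.synchronize()

    # Capture pass: record the decode step into a replayable graph. Later
    # iterations copy fresh data into static_input, call graph.replay(),
    # and read results from static_output.
    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph):
        static_output = model(static_input)
    return graph, static_input, static_output

PyTorch's documentation additionally recommends doing warm-up iterations on a side stream before capture; the sketch omits that for brevity. The commit achieves the same staging lazily: decode_one_tokens() stays eager until warm_uped is set, so only the second and later decode steps take the CUDAGraphRunner path.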
ktransformers/server/main.py

@@ -106,9 +106,6 @@ def custom_openapi(app):
 def main():
     cfg = Config()
-
-    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
-    cfg.use_cuda_graph = False
     arg_parser = ArgumentParser(cfg)
 
     # Initialization message
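With the hard-coded override removed, use_cuda_graph now keeps whatever value Config and the argument parser produce, which the warm-up fix above makes safe again. A hypothetical explicit opt-out, should a deployment still want eager decoding (the import path below is an assumption, not shown in this diff):

from ktransformers.server.config.config import Config  # path assumed, not in this diff

cfg = Config()
cfg.use_cuda_graph = False  # explicit caller choice; no longer forced inside main()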