Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ktransformers
Commits
fa03ea48
Commit
fa03ea48
authored
Mar 01, 2025
by
Atream
Browse files
Merge branch 'main' into feat-chunk-prefill-flashinfer
parents
f35e8d41
511958d4
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
146 deletions
+10
-146
ktransformers/server/backend/interfaces/ktransformers.py
ktransformers/server/backend/interfaces/ktransformers.py
+4
-1
ktransformers/server/backend/interfaces/transformers.py
ktransformers/server/backend/interfaces/transformers.py
+6
-0
test_prompt.txt
test_prompt.txt
+0
-145
No files found.
ktransformers/server/backend/interfaces/ktransformers.py
View file @
fa03ea48
...
@@ -129,8 +129,11 @@ class KTransformersInterface(TransformersInterface):
...
@@ -129,8 +129,11 @@ class KTransformersInterface(TransformersInterface):
@
torch
.
no_grad
@
torch
.
no_grad
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
,
temperature
:
Optional
[
float
],
top_p
:
Optional
[
float
]):
def
prefill
(
self
,
input_ids
:
torch
.
Tensor
,
is_new
:
bool
,
temperature
:
Optional
[
float
],
top_p
:
Optional
[
float
]):
input_ids_length
=
input_ids
.
shape
[
-
1
]
input_ids_length
=
input_ids
.
shape
[
-
1
]
if
(
input_ids_length
>=
self
.
args
.
cache_lens
):
logger
.
warning
(
f
"input_ids_length
{
input_ids_length
}
> cache_lens
{
self
.
args
.
cache_lens
}
"
)
self
.
seq_length
=
input_ids_length
return
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
logger
.
debug
(
f
"input_ids:
{
input_ids
.
shape
}
"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
self
.
device_map
.
get
(
"blk.0.self_attn"
,
{}).
get
(
"generate_device"
,
"cuda:0"
)
device
=
"cuda:0"
if
device
==
"cuda"
else
device
device
=
"cuda:0"
if
device
==
"cuda"
else
device
...
...
ktransformers/server/backend/interfaces/transformers.py
View file @
fa03ea48
...
@@ -328,6 +328,12 @@ class TransformersInterface(BackendInterfaceBase):
...
@@ -328,6 +328,12 @@ class TransformersInterface(BackendInterfaceBase):
@
torch
.
no_grad
@
torch
.
no_grad
def
generate
(
self
):
def
generate
(
self
):
self
.
args
.
max_new_tokens
=
min
(
self
.
args
.
max_new_tokens
,
self
.
args
.
cache_lens
-
self
.
seq_length
)
if
(
self
.
args
.
max_new_tokens
<=
0
):
logger
.
warning
(
"max_new_tokens is less than 0"
)
yield
self
.
streamer
.
end
()
return
logger
.
info
(
f
"max_new_tokens:
{
self
.
args
.
max_new_tokens
}
"
)
self
.
profiler
.
set_counter
(
"decode"
,
0
)
self
.
profiler
.
set_counter
(
"decode"
,
0
)
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
for
i
in
range
(
1
,
self
.
args
.
max_new_tokens
):
...
...
test_prompt.txt
deleted
100644 → 0
View file @
f35e8d41
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment