Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
45473d4b
Unverified
Commit
45473d4b
authored
Oct 04, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 04, 2024
Browse files
Make input_ids a torch.Tensor (#1568)
parent
114bbc86
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
7 deletions
+11
-7
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+9
-6
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+1
-0
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+1
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
45473d4b
...
...
@@ -514,9 +514,10 @@ class ScheduleBatch:
pt
+=
req
.
extend_input_len
# Set fields
self
.
input_ids
=
sum
(
input_ids
,
[])
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
,
device
=
"cuda"
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
device
=
"cuda"
)
with
out_cache_loc
.
device
:
self
.
input_ids
=
torch
.
tensor
(
sum
(
input_ids
,
[]),
dtype
=
torch
.
int32
)
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
)
self
.
extend_num_tokens
=
extend_num_tokens
self
.
out_cache_loc
=
out_cache_loc
...
...
@@ -536,7 +537,7 @@ class ScheduleBatch:
req
.
fill_ids
=
req
.
origin_input_ids
+
req
.
output_ids
req
.
extend_input_len
=
1
input_ids
=
self
.
input_ids
+
running_batch
.
input_ids
input_ids
=
torch
.
cat
([
self
.
input_ids
,
running_batch
.
input_ids
])
out_cache_loc
=
torch
.
cat
([
self
.
out_cache_loc
,
running_batch
.
out_cache_loc
])
extend_num_tokens
=
self
.
extend_num_tokens
+
running_bs
...
...
@@ -722,7 +723,9 @@ class ScheduleBatch:
for
r
in
self
.
reqs
]
self
.
input_ids
=
input_ids
self
.
input_ids
=
torch
.
tensor
(
input_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
seq_lens
.
device
)
self
.
seq_lens
.
add_
(
1
)
# Alloc mem
...
...
@@ -824,7 +827,7 @@ class ModelWorkerBatch:
# The forward mode
forward_mode
:
ForwardMode
# The input ids
input_ids
:
List
[
int
]
input_ids
:
torch
.
Tensor
# The indices of requests in the req_to_token_pool
req_pool_indices
:
torch
.
Tensor
# The sequence length
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
45473d4b
...
...
@@ -30,6 +30,7 @@ class ReqToTokenPool:
def
__init__
(
self
,
size
:
int
,
max_context_len
:
int
,
device
:
str
):
self
.
size
=
size
self
.
max_context_len
=
max_context_len
self
.
device
=
device
self
.
free_slots
=
list
(
range
(
size
))
self
.
req_to_token
=
torch
.
empty
(
(
size
,
max_context_len
),
dtype
=
torch
.
int32
,
device
=
device
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
45473d4b
...
...
@@ -123,7 +123,7 @@ class ForwardBatch:
ret
=
cls
(
forward_mode
=
batch
.
forward_mode
,
batch_size
=
len
(
batch
.
seq_lens
),
input_ids
=
torch
.
tensor
(
batch
.
input_ids
,
dtype
=
torch
.
int32
,
device
=
device
),
input_ids
=
batch
.
input_ids
,
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
out_cache_loc
=
batch
.
out_cache_loc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment