Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
392f2863
Unverified
Commit
392f2863
authored
Oct 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 18, 2024
Browse files
Add dtype for more operations (#1705)
parent
6d0fa73e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
4 deletions
+5
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-2
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-1
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+1
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
392f2863
...
@@ -537,8 +537,8 @@ class ScheduleBatch:
...
@@ -537,8 +537,8 @@ class ScheduleBatch:
# Set fields
# Set fields
with
out_cache_loc
.
device
:
with
out_cache_loc
.
device
:
self
.
input_ids
=
torch
.
tensor
(
sum
(
input_ids
,
[]),
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
tensor
(
sum
(
input_ids
,
[]),
dtype
=
torch
.
int32
)
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
)
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int32
)
self
.
extend_num_tokens
=
extend_num_tokens
self
.
extend_num_tokens
=
extend_num_tokens
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
392f2863
...
@@ -145,8 +145,9 @@ class ForwardBatch:
...
@@ -145,8 +145,9 @@ class ForwardBatch:
],
],
axis
=
0
,
axis
=
0
,
),
),
dtype
=
torch
.
int64
,
device
=
device
,
device
=
device
,
)
.
to
(
torch
.
int64
)
)
ret
.
image_inputs
=
batch
.
image_inputs
ret
.
image_inputs
=
batch
.
image_inputs
ret
.
extend_seq_lens
=
torch
.
tensor
(
batch
.
extend_seq_lens
,
device
=
device
)
ret
.
extend_seq_lens
=
torch
.
tensor
(
batch
.
extend_seq_lens
,
device
=
device
)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
392f2863
...
@@ -57,7 +57,7 @@ class SamplingBatchInfo:
...
@@ -57,7 +57,7 @@ class SamplingBatchInfo:
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
)
)
top_ks
=
torch
.
tensor
(
top_ks
=
torch
.
tensor
(
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
32
)
)
min_ps
=
torch
.
tensor
(
min_ps
=
torch
.
tensor
(
[
r
.
sampling_params
.
min_p
for
r
in
reqs
],
dtype
=
torch
.
float
[
r
.
sampling_params
.
min_p
for
r
in
reqs
],
dtype
=
torch
.
float
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment