Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
392f2863
"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d03240801f2ac2b4d1f49584c1c5628b98583f6a"
Unverified
Commit
392f2863
authored
Oct 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 18, 2024
Browse files
Add dtype for more operations (#1705)
parent
6d0fa73e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
4 deletions
+5
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-2
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-1
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+1
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
392f2863
...
@@ -537,8 +537,8 @@ class ScheduleBatch:
...
@@ -537,8 +537,8 @@ class ScheduleBatch:
# Set fields
# Set fields
with
out_cache_loc
.
device
:
with
out_cache_loc
.
device
:
self
.
input_ids
=
torch
.
tensor
(
sum
(
input_ids
,
[]),
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
tensor
(
sum
(
input_ids
,
[]),
dtype
=
torch
.
int32
)
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
)
self
.
req_pool_indices
=
torch
.
tensor
(
req_pool_indices
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int32
)
self
.
extend_num_tokens
=
extend_num_tokens
self
.
extend_num_tokens
=
extend_num_tokens
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
392f2863
...
@@ -145,8 +145,9 @@ class ForwardBatch:
...
@@ -145,8 +145,9 @@ class ForwardBatch:
],
],
axis
=
0
,
axis
=
0
,
),
),
dtype
=
torch
.
int64
,
device
=
device
,
device
=
device
,
)
.
to
(
torch
.
int64
)
)
ret
.
image_inputs
=
batch
.
image_inputs
ret
.
image_inputs
=
batch
.
image_inputs
ret
.
extend_seq_lens
=
torch
.
tensor
(
batch
.
extend_seq_lens
,
device
=
device
)
ret
.
extend_seq_lens
=
torch
.
tensor
(
batch
.
extend_seq_lens
,
device
=
device
)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
392f2863
...
@@ -57,7 +57,7 @@ class SamplingBatchInfo:
...
@@ -57,7 +57,7 @@ class SamplingBatchInfo:
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
)
)
top_ks
=
torch
.
tensor
(
top_ks
=
torch
.
tensor
(
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
32
)
)
min_ps
=
torch
.
tensor
(
min_ps
=
torch
.
tensor
(
[
r
.
sampling_params
.
min_p
for
r
in
reqs
],
dtype
=
torch
.
float
[
r
.
sampling_params
.
min_p
for
r
in
reqs
],
dtype
=
torch
.
float
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment