Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fbb4754c
"torch_cluster/functions/serial.py" did not exist on "65846a615daf0e78494026dbc5570c613fb14902"
Unverified
Commit
fbb4754c
authored
Sep 10, 2024
by
Liangsheng Yin
Committed by
GitHub
Sep 10, 2024
Browse files
Fix vocab mask update bug (#1376)
parent
6c7cb903
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
22 deletions
+29
-22
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+0
-2
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-1
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+27
-19
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
fbb4754c
...
...
@@ -652,8 +652,6 @@ class ScheduleBatch:
self
.
req_pool_indices
,
self
.
seq_lens
-
1
]
=
self
.
out_cache_loc
self
.
sampling_info
.
update_regex_vocab_mask
(
self
)
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
]):
if
unfinished_indices
is
None
or
len
(
unfinished_indices
)
==
0
:
# Filter out all requests
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
fbb4754c
...
...
@@ -195,7 +195,8 @@ class InputMetadata:
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
)
ret
.
sampling_info
.
prepare_penalties
()
ret
.
sampling_info
.
update_penalties
()
ret
.
sampling_info
.
update_regex_vocab_mask
(
batch
)
ret
.
compute_positions
(
batch
)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
fbb4754c
...
...
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
linear_penalties
:
torch
.
Tensor
=
None
scaling_penalties
:
torch
.
Tensor
=
None
def
__len__
(
self
):
return
len
(
self
.
temperatures
)
def
can_run_in_cuda_graph
(
self
):
# Vocab bias and min_ps are not supported in CUDA graph
return
(
...
...
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret
.
logit_bias
=
None
ret
.
update_regex_vocab_mask
(
batch
)
return
ret
def
prepar
e_penalties
(
self
):
def
updat
e_penalties
(
self
):
self
.
scaling_penalties
=
None
self
.
linear_penalties
=
None
...
...
@@ -174,6 +175,26 @@ class SamplingBatchInfo:
if
self_val
is
not
None
:
# logit_bias can be None
setattr
(
self
,
item
,
self_val
[
new_indices
])
@
staticmethod
def
merge_bias_tensor
(
lhs
:
torch
.
Tensor
,
rhs
:
torch
.
Tensor
,
bs1
:
int
,
bs2
:
int
,
default
:
int
=
0
):
# bias tensor can be None
if
lhs
is
not
None
or
rhs
is
not
None
:
shape
,
dtype
=
None
,
None
if
lhs
is
not
None
:
shape
,
dtype
=
lhs
.
shape
[
1
:],
lhs
.
dtype
else
:
shape
,
dtype
=
rhs
.
shape
[
1
:],
rhs
.
dtype
with
torch
.
dtype
(
dtype
):
if
lhs
is
None
:
lhs
=
torch
.
empty
((
bs1
,
*
shape
),
device
=
"cuda"
).
fill_
(
default
)
if
rhs
is
None
:
rhs
=
torch
.
empty
((
bs2
,
*
shape
),
device
=
"cuda"
).
fill_
(
default
)
return
torch
.
cat
([
lhs
,
rhs
])
return
None
def
merge
(
self
,
other
:
"SamplingBatchInfo"
):
self
.
penalizer_orchestrator
.
merge
(
other
.
penalizer_orchestrator
)
...
...
@@ -187,19 +208,6 @@ class SamplingBatchInfo:
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
# logit_bias can be None
if
self
.
logit_bias
is
not
None
or
other
.
logit_bias
is
not
None
:
vocab_size
=
(
self
.
logit_bias
.
shape
[
1
]
if
self
.
logit_bias
is
not
None
else
other
.
logit_bias
.
shape
[
1
]
)
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
len
(
self
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
other
.
logit_bias
is
None
:
other
.
logit_bias
=
torch
.
zeros
(
(
len
(
other
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
)
)
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment