Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
fbb4754c
Unverified
Commit
fbb4754c
authored
Sep 10, 2024
by
Liangsheng Yin
Committed by
GitHub
Sep 10, 2024
Browse files
Fix vocab mask update bug (#1376)
parent
6c7cb903
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
22 deletions
+29
-22
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+0
-2
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-1
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+27
-19
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
fbb4754c
...
...
@@ -652,8 +652,6 @@ class ScheduleBatch:
self
.
req_pool_indices
,
self
.
seq_lens
-
1
]
=
self
.
out_cache_loc
self
.
sampling_info
.
update_regex_vocab_mask
(
self
)
def
filter_batch
(
self
,
unfinished_indices
:
List
[
int
]):
if
unfinished_indices
is
None
or
len
(
unfinished_indices
)
==
0
:
# Filter out all requests
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
fbb4754c
...
...
@@ -195,7 +195,8 @@ class InputMetadata:
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
)
ret
.
sampling_info
.
prepare_penalties
()
ret
.
sampling_info
.
update_penalties
()
ret
.
sampling_info
.
update_regex_vocab_mask
(
batch
)
ret
.
compute_positions
(
batch
)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
fbb4754c
...
...
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
linear_penalties
:
torch
.
Tensor
=
None
scaling_penalties
:
torch
.
Tensor
=
None
def
__len__
(
self
):
return
len
(
self
.
temperatures
)
def
can_run_in_cuda_graph
(
self
):
# Vocab bias and min_ps are not supported in CUDA graph
return
(
...
...
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret
.
logit_bias
=
None
ret
.
update_regex_vocab_mask
(
batch
)
return
ret
def
prepar
e_penalties
(
self
):
def
updat
e_penalties
(
self
):
self
.
scaling_penalties
=
None
self
.
linear_penalties
=
None
...
...
@@ -174,6 +175,26 @@ class SamplingBatchInfo:
if
self_val
is
not
None
:
# logit_bias can be None
setattr
(
self
,
item
,
self_val
[
new_indices
])
@
staticmethod
def
merge_bias_tensor
(
lhs
:
torch
.
Tensor
,
rhs
:
torch
.
Tensor
,
bs1
:
int
,
bs2
:
int
,
default
:
int
=
0
):
# bias tensor can be None
if
lhs
is
not
None
or
rhs
is
not
None
:
shape
,
dtype
=
None
,
None
if
lhs
is
not
None
:
shape
,
dtype
=
lhs
.
shape
[
1
:],
lhs
.
dtype
else
:
shape
,
dtype
=
rhs
.
shape
[
1
:],
rhs
.
dtype
with
torch
.
dtype
(
dtype
):
if
lhs
is
None
:
lhs
=
torch
.
empty
((
bs1
,
*
shape
),
device
=
"cuda"
).
fill_
(
default
)
if
rhs
is
None
:
rhs
=
torch
.
empty
((
bs2
,
*
shape
),
device
=
"cuda"
).
fill_
(
default
)
return
torch
.
cat
([
lhs
,
rhs
])
return
None
def
merge
(
self
,
other
:
"SamplingBatchInfo"
):
self
.
penalizer_orchestrator
.
merge
(
other
.
penalizer_orchestrator
)
...
...
@@ -187,19 +208,6 @@ class SamplingBatchInfo:
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
# logit_bias can be None
if
self
.
logit_bias
is
not
None
or
other
.
logit_bias
is
not
None
:
vocab_size
=
(
self
.
logit_bias
.
shape
[
1
]
if
self
.
logit_bias
is
not
None
else
other
.
logit_bias
.
shape
[
1
]
)
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
len
(
self
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
other
.
logit_bias
is
None
:
other
.
logit_bias
=
torch
.
zeros
(
(
len
(
other
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment