Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ed27a6b9
Unverified
Commit
ed27a6b9
authored
Apr 03, 2024
by
Liangsheng Yin
Committed by
GitHub
Apr 03, 2024
Browse files
Revert "Eliminate 2 gpu ops during sampling when logit_bias is zero" (#345)
parent
463c6632
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
32 deletions
+8
-32
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+8
-32
No files found.
python/sglang/srt/managers/router/infer_batch.py
View file @
ed27a6b9
...
...
@@ -251,14 +251,10 @@ class Batch:
]
=
out_cache_loc
[
pt
:
pt
+
extend_lens
[
i
]]
pt
+=
extend_lens
[
i
]
# Handle logit bias
but only allocate when needed
logit_bias
=
None
# Handle logit bias
logit_bias
=
torch
.
zeros
((
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
for
i
in
range
(
bs
):
if
reqs
[
i
].
sampling_params
.
dtype
==
"int"
:
if
logit_bias
is
None
:
logit_bias
=
torch
.
zeros
(
(
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
logit_bias
[
i
]
=
int_token_logit_bias
# Set fields
...
...
@@ -437,12 +433,9 @@ class Batch:
"presence_penalties"
,
"logit_bias"
,
]:
self_val
=
getattr
(
self
,
item
,
None
)
# logit_bias can be None
if
self_val
is
not
None
:
setattr
(
self
,
item
,
self_val
[
new_indices
])
setattr
(
self
,
item
,
getattr
(
self
,
item
)[
new_indices
])
def
merge
(
self
,
other
:
"Batch"
):
def
merge
(
self
,
other
):
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
req_pool_indices
=
torch
.
concat
(
...
...
@@ -463,34 +456,17 @@ class Batch:
"top_ks"
,
"frequency_penalties"
,
"presence_penalties"
,
"logit_bias"
,
]:
self_val
=
getattr
(
self
,
item
,
None
)
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
# logit_bias can be None
if
self
.
logit_bias
is
not
None
or
other
.
logit_bias
is
not
None
:
vocab_size
=
(
self
.
logit_bias
.
shape
[
1
]
if
self
.
logit_bias
is
not
None
else
other
.
logit_bias
.
shape
[
1
]
setattr
(
self
,
item
,
torch
.
concat
([
getattr
(
self
,
item
),
getattr
(
other
,
item
)])
)
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
len
(
self
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
other
.
logit_bias
is
None
:
other
.
logit_bias
=
torch
.
zeros
(
(
len
(
other
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
def
sample
(
self
,
logits
:
torch
.
Tensor
):
# Post process logits
logits
=
logits
.
contiguous
()
logits
.
div_
(
self
.
temperatures
)
if
self
.
logit_bias
is
not
None
:
logits
.
add_
(
self
.
logit_bias
)
logits
.
add_
(
self
.
logit_bias
)
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
reqs
)
if
has_regex
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment