Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c9de3e16
Unverified
Commit
c9de3e16
authored
Apr 03, 2024
by
Qubitium
Committed by
GitHub
Apr 03, 2024
Browse files
Eliminate 2 gpu ops during sampling when logit_bias is zero (#338)
Co-authored-by:
hnyls2002
<
hnyls2002@gmail.com
>
parent
ed27a6b9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
8 deletions
+32
-8
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+32
-8
No files found.
python/sglang/srt/managers/router/infer_batch.py
View file @
c9de3e16
...
...
@@ -251,10 +251,14 @@ class Batch:
]
=
out_cache_loc
[
pt
:
pt
+
extend_lens
[
i
]]
pt
+=
extend_lens
[
i
]
# Handle logit bias
logit_bias
=
torch
.
zeros
((
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
# Handle logit bias
but only allocate when needed
logit_bias
=
None
for
i
in
range
(
bs
):
if
reqs
[
i
].
sampling_params
.
dtype
==
"int"
:
if
logit_bias
is
None
:
logit_bias
=
torch
.
zeros
(
(
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
logit_bias
[
i
]
=
int_token_logit_bias
# Set fields
...
...
@@ -433,9 +437,12 @@ class Batch:
"presence_penalties"
,
"logit_bias"
,
]:
setattr
(
self
,
item
,
getattr
(
self
,
item
)[
new_indices
])
self_val
=
getattr
(
self
,
item
,
None
)
# logit_bias can be None
if
self_val
is
not
None
:
setattr
(
self
,
item
,
self_val
[
new_indices
])
def
merge
(
self
,
other
):
def
merge
(
self
,
other
:
"Batch"
):
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
req_pool_indices
=
torch
.
concat
(
...
...
@@ -456,16 +463,33 @@ class Batch:
"top_ks"
,
"frequency_penalties"
,
"presence_penalties"
,
"logit_bias"
,
]:
setattr
(
self
,
item
,
torch
.
concat
([
getattr
(
self
,
item
),
getattr
(
other
,
item
)])
self_val
=
getattr
(
self
,
item
,
None
)
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
# logit_bias can be None
if
self
.
logit_bias
is
not
None
or
other
.
logit_bias
is
not
None
:
vocab_size
=
(
self
.
logit_bias
.
shape
[
1
]
if
self
.
logit_bias
is
not
None
else
other
.
logit_bias
.
shape
[
1
]
)
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
len
(
self
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
other
.
logit_bias
is
None
:
other
.
logit_bias
=
torch
.
zeros
(
(
len
(
other
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
def
sample
(
self
,
logits
:
torch
.
Tensor
):
# Post process logits
logits
=
logits
.
contiguous
()
logits
.
div_
(
self
.
temperatures
)
if
self
.
logit_bias
is
not
None
:
logits
.
add_
(
self
.
logit_bias
)
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
reqs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment