Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c9de3e16
Unverified
Commit
c9de3e16
authored
Apr 03, 2024
by
Qubitium
Committed by
GitHub
Apr 03, 2024
Browse files
Eliminate 2 gpu ops during sampling when logit_bias is zero (#338)
Co-authored-by:
hnyls2002
<
hnyls2002@gmail.com
>
parent
ed27a6b9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
8 deletions
+32
-8
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+32
-8
No files found.
python/sglang/srt/managers/router/infer_batch.py
View file @
c9de3e16
...
@@ -251,10 +251,14 @@ class Batch:
...
@@ -251,10 +251,14 @@ class Batch:
]
=
out_cache_loc
[
pt
:
pt
+
extend_lens
[
i
]]
]
=
out_cache_loc
[
pt
:
pt
+
extend_lens
[
i
]]
pt
+=
extend_lens
[
i
]
pt
+=
extend_lens
[
i
]
# Handle logit bias
# Handle logit bias
but only allocate when needed
logit_bias
=
torch
.
zeros
((
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
logit_bias
=
None
for
i
in
range
(
bs
):
for
i
in
range
(
bs
):
if
reqs
[
i
].
sampling_params
.
dtype
==
"int"
:
if
reqs
[
i
].
sampling_params
.
dtype
==
"int"
:
if
logit_bias
is
None
:
logit_bias
=
torch
.
zeros
(
(
bs
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
)
logit_bias
[
i
]
=
int_token_logit_bias
logit_bias
[
i
]
=
int_token_logit_bias
# Set fields
# Set fields
...
@@ -433,9 +437,12 @@ class Batch:
...
@@ -433,9 +437,12 @@ class Batch:
"presence_penalties"
,
"presence_penalties"
,
"logit_bias"
,
"logit_bias"
,
]:
]:
setattr
(
self
,
item
,
getattr
(
self
,
item
)[
new_indices
])
self_val
=
getattr
(
self
,
item
,
None
)
# logit_bias can be None
if
self_val
is
not
None
:
setattr
(
self
,
item
,
self_val
[
new_indices
])
def
merge
(
self
,
other
):
def
merge
(
self
,
other
:
"Batch"
):
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
reqs
.
extend
(
other
.
reqs
)
self
.
req_pool_indices
=
torch
.
concat
(
self
.
req_pool_indices
=
torch
.
concat
(
...
@@ -456,16 +463,33 @@ class Batch:
...
@@ -456,16 +463,33 @@ class Batch:
"top_ks"
,
"top_ks"
,
"frequency_penalties"
,
"frequency_penalties"
,
"presence_penalties"
,
"presence_penalties"
,
"logit_bias"
,
]:
]:
setattr
(
self_val
=
getattr
(
self
,
item
,
None
)
self
,
item
,
torch
.
concat
([
getattr
(
self
,
item
),
getattr
(
other
,
item
)])
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
# logit_bias can be None
if
self
.
logit_bias
is
not
None
or
other
.
logit_bias
is
not
None
:
vocab_size
=
(
self
.
logit_bias
.
shape
[
1
]
if
self
.
logit_bias
is
not
None
else
other
.
logit_bias
.
shape
[
1
]
)
if
self
.
logit_bias
is
None
:
self
.
logit_bias
=
torch
.
zeros
(
(
len
(
self
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
if
other
.
logit_bias
is
None
:
other
.
logit_bias
=
torch
.
zeros
(
(
len
(
other
.
reqs
),
vocab_size
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
)
self
.
logit_bias
=
torch
.
concat
([
self
.
logit_bias
,
other
.
logit_bias
])
def
sample
(
self
,
logits
:
torch
.
Tensor
):
def
sample
(
self
,
logits
:
torch
.
Tensor
):
# Post process logits
# Post process logits
logits
=
logits
.
contiguous
()
logits
=
logits
.
contiguous
()
logits
.
div_
(
self
.
temperatures
)
logits
.
div_
(
self
.
temperatures
)
if
self
.
logit_bias
is
not
None
:
logits
.
add_
(
self
.
logit_bias
)
logits
.
add_
(
self
.
logit_bias
)
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
reqs
)
has_regex
=
any
(
req
.
regex_fsm
is
not
None
for
req
in
self
.
reqs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment