zhaoyu6 / sglang · Commit 0a409bd4 (Unverified)
Authored Jul 27, 2024 by Lianmin Zheng; committed by GitHub on Jul 27, 2024
Fix return_log_probs with cuda graph (#775)
Parent: e4db4e5b
Showing 2 changed files with 62 additions and 40 deletions:

  python/sglang/srt/layers/logits_processor.py (+38 -25)
  python/sglang/srt/managers/controller/cuda_graph_runner.py (+24 -15)
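Taken together, the two files implement one fix: LogitsProcessor.forward gains a dedicated decode path that derives logprobs from last_logits alone, and CudaGraphRunner.replay stops asserting not batch.return_logprob, instead computing next_token_logprobs and the top-k logprobs eagerly after the graph replay. A hypothetical repro sketch for the case this fixes follows; the endpoint, port, and request field names are assumptions about sglang's HTTP API at the time, not taken from this commit:

    # Hypothetical repro: request per-token logprobs from a server running with
    # CUDA graphs enabled; before this commit, replay() asserted this case away.
    import requests  # assumes a local sglang server, e.g. on port 30000

    resp = requests.post(
        "http://localhost:30000/generate",  # endpoint/fields are assumptions
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 8},
            "return_logprob": True,
            "top_logprobs_num": 5,
        },
    )
    print(resp.json())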
python/sglang/srt/layers/logits_processor.py
"""Logits processing."""
 
 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 from torch import nn
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False
 
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
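The new Optional defaults are what allow a metadata object to be built for a decode-only logprob pass, with none of the prefill fields supplied. A minimal sketch, assuming the sglang imports resolve (it mirrors the LogitsMetadata(...) call added to cuda_graph_runner.py below):

    # Minimal sketch: the Optional defaults let a decode-time caller build the
    # metadata with no prefill fields, mirroring replay() in the second file.
    from sglang.srt.layers.logits_processor import LogitsMetadata
    from sglang.srt.managers.controller.infer_batch import ForwardMode

    logits_metadata = LogitsMetadata(
        forward_mode=ForwardMode.DECODE,
        top_logprobs_nums=[5, 0, 3],  # per-request top-k; 0 disables top-k
    )
    assert logits_metadata.return_logprob is False  # new default
    assert logits_metadata.extend_seq_lens is None  # prefill-only fields stay None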
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
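Dropping the leading underscore and the self parameter turns the helper into a public staticmethod, so CudaGraphRunner can call it without holding a LogitsProcessor instance (see the second file below). A rough sketch with toy inputs, assuming the sglang imports resolve:

    import torch
    from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
    from sglang.srt.managers.controller.infer_batch import ForwardMode

    # Toy decode-step logprobs: 2 requests over a 10-token vocabulary.
    last_logprobs = torch.nn.functional.log_softmax(torch.randn(2, 10), dim=-1)
    logits_metadata = LogitsMetadata(
        forward_mode=ForwardMode.DECODE,
        top_logprobs_nums=[3, 2],
    )

    # No instance needed; in DECODE mode the helper returns a
    # (prefill_top_logprobs, decode_top_logprobs) pair, and [1] keeps the decode half.
    decode_top_logprobs = LogitsProcessor.get_top_logprobs(
        last_logprobs, logits_metadata
    )[1]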
@@ -156,36 +157,48 @@ class LogitsProcessor(nn.Module):
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size].float()
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-            all_logprobs = all_logits
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                if return_top_logprob:
+                    decode_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    decode_top_logprobs = None
 
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
                     prefill_token_logprobs=None,
                     prefill_top_logprobs=None,
                     decode_top_logprobs=decode_top_logprobs,
                 )
             else:
-                last_logprobs = all_logprobs[last_index]
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    prefill_top_logprobs = decode_top_logprobs = None
+
+                last_logprobs = all_logprobs[last_index]
 
         # Compute the logprobs and normalized logprobs for the prefill tokens.
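On the LogitsProcessor side, the restructuring means a decode step never materializes the batch-wide all_logits/all_logprobs buffers: in decode mode only the last position's logits exist, so one log_softmax over the vocab dimension suffices and the function returns early, while the prefill path keeps its original logic. A standalone sketch of that shape reasoning (sizes are illustrative):

    import torch

    batch_size, vocab_size = 4, 32000  # illustrative sizes
    last_logits = torch.randn(batch_size, vocab_size)

    # One log-softmax over the vocab dimension gives, per request, the
    # log-probability of every candidate next token; no prefill buffers needed.
    last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
    assert last_logprobs.shape == (batch_size, vocab_size)
    assert torch.allclose(last_logprobs.exp().sum(-1), torch.ones(batch_size), atol=1e-4)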
python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:
     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)
 
         # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
output
=
self
.
output_buffers
[
bs
]
# Unpad
if
bs
==
raw_bs
:
return
output
else
:
if
bs
!=
raw_bs
:
output
=
LogitProcessorOutput
(
next_token_logits
=
output
.
next_token_logits
[:
raw_bs
],
next_token_logprobs
=
(
output
.
next_token_logprobs
[:
raw_bs
]
if
output
.
next_token_logprobs
is
not
None
else
None
),
next_token_logprobs
=
None
,
normalized_prompt_logprobs
=
None
,
prefill_token_logprobs
=
None
,
prefill_top_logprobs
=
None
,
decode_top_logprobs
=
(
output
.
decode_top_logprobs
[:
raw_bs
]
if
output
.
decode_top_logprobs
is
not
None
else
None
),
decode_top_logprobs
=
None
,
)
# Extract logprobs
if
batch
.
return_logprob
:
output
.
next_token_logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
output
.
next_token_logits
,
dim
=-
1
)
return_top_logprob
=
any
(
x
>
0
for
x
in
batch
.
top_logprobs_nums
)
if
return_top_logprob
:
logits_metadata
=
LogitsMetadata
(
forward_mode
=
ForwardMode
.
DECODE
,
top_logprobs_nums
=
batch
.
top_logprobs_nums
,
)
output
.
decode_top_logprobs
=
LogitsProcessor
.
get_top_logprobs
(
output
.
next_token_logprobs
,
logits_metadata
)[
1
]
return
output
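This is the other half of the fix: the captured graph still produces only next_token_logits (padded to the captured batch size), while log_softmax and the top-k extraction now run eagerly on the unpadded tensor after replay, so logprob requests no longer need to bypass the CUDA graph path. A minimal standalone sketch of that post-replay step, with dummy tensors standing in for the graph output and a plain topk standing in for LogitsProcessor.get_top_logprobs:

    import torch

    # Dummy stand-ins: graph captured at batch size 8, but only 3 real requests.
    raw_bs, padded_bs, vocab = 3, 8, 32000
    graph_logits = torch.randn(padded_bs, vocab)  # replayed output buffer
    next_token_logits = graph_logits[:raw_bs]     # unpad first

    # Eager post-processing, as replay() now does outside the captured graph.
    next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

    top_logprobs_nums = [5, 0, 2]  # per-request top-k asks; 0 means none
    if any(x > 0 for x in top_logprobs_nums):
        # Simplified stand-in for LogitsProcessor.get_top_logprobs: one shared k.
        k = max(top_logprobs_nums)
        values, indices = next_token_logprobs.topk(k, dim=-1)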