sglang · Commit cb099d20 (Unverified)

[CUDA Graph] save cuda graph memory by using next_token_logits_buffer (#8579)

Authored Aug 03, 2025 by Cheng Wan; committed via GitHub on Aug 03, 2025.
Parent: 7a913301
Showing 5 changed files with 36 additions and 1 deletion:

- python/sglang/srt/layers/logits_processor.py (+9, -1)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+7, -0)
- python/sglang/srt/model_executor/forward_batch_info.py (+1, -0)
- python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+18, -0)
- python/sglang/srt/two_batch_overlap.py (+1, -0)
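The idea behind the change, per the commit title: previously, each captured batch size materialized its own float32 logits tensor via `.float()`, and that allocation stayed alive with its CUDA graph. Preallocating one `next_token_logits_buffer` and writing into slices of it lets all captured graphs share a single allocation. A minimal standalone sketch of the pattern (illustrative names and shapes, not sglang code; assumes a CUDA device):

```python
import torch

# One shared float32 output buffer, sized for the largest capture; each
# captured graph writes into a slice instead of owning a fresh allocation.
max_tokens, hidden_size, vocab_size = 8, 1024, 32000
hidden = torch.randn(max_tokens, hidden_size, device="cuda", dtype=torch.float16)
lm_head = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.float16)
next_token_logits_buffer = torch.zeros(max_tokens, vocab_size, device="cuda")

def forward(num_tokens: int) -> torch.Tensor:
    out = next_token_logits_buffer[:num_tokens]
    out.copy_(hidden[:num_tokens] @ lm_head.t())  # copy_ upcasts fp16 -> fp32
    return out

# Warm up on a side stream, then capture; replays reuse the same storage.
warmup_stream = torch.cuda.Stream()
warmup_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup_stream):
    forward(max_tokens)
torch.cuda.current_stream().wait_stream(warmup_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    forward(max_tokens)
graph.replay()
```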
python/sglang/srt/layers/logits_processor.py

```diff
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)
 
-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
```
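In isolation, the new branch trades a fresh `.float()` allocation for an in-place `copy_` into a caller-supplied buffer. A hypothetical standalone function showing the same control flow (`finalize_logits` is an illustrative name, not part of the diff):

```python
from typing import Optional

import torch

def finalize_logits(
    logits: torch.Tensor,
    vocab_size: int,
    buffer: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # With a preallocated float32 buffer, copy_ truncates to vocab_size and
    # upcasts from the model dtype in place; otherwise .float() allocates.
    if buffer is not None:
        assert buffer.dtype == torch.float
        buffer.copy_(logits[:, :vocab_size])
        return buffer
    return logits[:, :vocab_size].float()

raw = torch.randn(4, 32064, dtype=torch.float16)  # padded vocab dimension
buf = torch.zeros(4, 32000)
out = finalize_logits(raw, vocab_size=32000, buffer=buf)
assert out.data_ptr() == buf.data_ptr()  # no new allocation on this path
```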
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -375,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -520,6 +525,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism
@@ -582,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
```
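Note that `self.next_token_logits_buffer[:num_tokens]` is a view, so every batch size the runner captures aliases the same storage. A small sketch (CPU tensors for simplicity; shapes are illustrative):

```python
import torch

max_num_token, vocab_size = 16, 32000
next_token_logits_buffer = torch.zeros(max_num_token, vocab_size)

for num_tokens in (1, 2, 4, 8, 16):
    view = next_token_logits_buffer[:num_tokens]
    # Slicing the leading dimension returns a view starting at the same
    # storage offset, so no per-batch-size logits tensor is retained.
    assert view.data_ptr() == next_token_logits_buffer.data_ptr()
```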
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -189,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
```
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.global_num_tokens_for_logprob_gpu = None
             self.gathered_buffer = None
 
+        if hasattr(self.model_runner.model_config.hf_config, "draft_vocab_size"):
+            # llama_eagle
+            vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
+        elif hasattr(self.model_runner.model_config.hf_config, "hot_vocab_size"):
+            # llama_eagle3
+            vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
+        else:
+            vocab_size = self.model_runner.model_config.vocab_size
+
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_bs, vocab_size),
+            dtype=torch.float,
+        )
+
         # Capture
         try:
             with model_capture_mode():
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
+        next_token_logits_buffer = self.next_token_logits_buffer[:bs]
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
```
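The draft runner has to size its buffer by the draft model's vocabulary, which lives under different `hf_config` attributes depending on the EAGLE variant. A sketch of the same fallback chain with a stand-in config object (`resolve_vocab_size` is an illustrative name, not part of the diff):

```python
from types import SimpleNamespace

def resolve_vocab_size(hf_config, default_vocab_size: int) -> int:
    if hasattr(hf_config, "draft_vocab_size"):  # llama_eagle
        return hf_config.draft_vocab_size
    if hasattr(hf_config, "hot_vocab_size"):    # llama_eagle3
        return hf_config.hot_vocab_size
    return default_vocab_size                   # plain target-model vocab

assert resolve_vocab_size(SimpleNamespace(hot_vocab_size=32000), 128256) == 32000
assert resolve_vocab_size(SimpleNamespace(), 128256) == 128256
```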
python/sglang/srt/two_batch_overlap.py

```diff
@@ -564,6 +564,7 @@ class TboForwardBatchPreparer:
                 mm_inputs=None,
                 top_logprobs_nums=None,
                 token_ids_logprobs=None,
+                next_token_logits_buffer=None,
             )
         )
```