Unverified commit cb099d20 authored by Cheng Wan, committed by GitHub

[CUDA Graph] save cuda graph memory by using next_token_logits_buffer (#8579)

parent 7a913301
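
The idea behind the change: LogitsProcessor used to call .float() on the trimmed logits, and when that op is captured under CUDA graphs, every captured batch size bakes its own fp32 allocation into the graph's private memory pool. Preallocating a single fp32 buffer at the maximum token count and copying into a slice of it lets all captured graphs share one allocation. A minimal standalone sketch of the pattern (hypothetical sizes, not the sglang API):

import torch

VOCAB_SIZE = 32_000   # hypothetical vocab size
MAX_NUM_TOKEN = 160   # hypothetical largest captured token count

# Before: each capture materializes a fresh fp32 tensor inside the graph.
def old_path(raw_logits: torch.Tensor) -> torch.Tensor:
    return raw_logits[:, :VOCAB_SIZE].float()  # new allocation per captured graph

# After: one shared fp32 buffer, sized once for the largest capture.
next_token_logits_buffer = torch.zeros(
    (MAX_NUM_TOKEN, VOCAB_SIZE), dtype=torch.float, device="cuda"
)

def new_path(raw_logits: torch.Tensor) -> torch.Tensor:
    out = next_token_logits_buffer[: raw_logits.shape[0]]  # view, no allocation
    out.copy_(raw_logits[:, :VOCAB_SIZE])  # copy_ up-casts to fp32 in place
    return out
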
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)

-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
...
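
Why an assert on dtype is enough in the hunk above: Tensor.copy_ converts between dtypes, so copying bf16/fp16 logits into the fp32 buffer performs the same up-cast that .float() did, just without a new allocation. A quick standalone check (hypothetical shapes):

import torch

src = torch.randn(4, 8, dtype=torch.bfloat16)  # raw logits in low precision
buf = torch.zeros(4, 8, dtype=torch.float)     # preallocated fp32 buffer

buf.copy_(src)                 # copy_ up-casts bf16 -> fp32 element-wise
assert buf.dtype == torch.float
assert torch.equal(buf, src.float())  # bf16 -> fp32 is exact, so values match
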
@@ -375,6 +375,11 @@ class CudaGraphRunner:
                 dtype=torch.bool,
                 device="cuda",
             )
+            self.next_token_logits_buffer = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.vocab_size),
+                dtype=torch.float,
+                device="cuda",
+            )

         # Capture
         try:
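
Rough arithmetic for the footprint of this shared buffer versus per-graph allocations (illustrative numbers, not measurements from the PR):

# Illustrative only: assumes a 128k-entry vocab and 160 max captured tokens.
vocab_size = 128_256
max_num_token = 160
fp32_bytes = 4

buffer_mib = max_num_token * vocab_size * fp32_bytes / 2**20
print(f"shared logits buffer: {buffer_mib:.1f} MiB")  # ~78 MiB, paid once
# Previously each captured batch size could hold its own fp32 logits tensor,
# so the total scaled with the number of capture sizes.
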
@@ -520,6 +525,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens

         # pipeline parallelism
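
Slicing the buffer as above returns a view, so graphs captured at different token counts all replay into the front of the same storage. A minimal demonstration:

import torch

buf = torch.zeros(160, 8)
small = buf[:16]    # view used by a 16-token graph
large = buf[:128]   # view used by a 128-token graph

# All three share one allocation; the slices alias its front.
assert small.data_ptr() == large.data_ptr() == buf.data_ptr()
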
@@ -582,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
...
@@ -189,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None

     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
...
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
             self.global_num_tokens_for_logprob_gpu = None
             self.gathered_buffer = None

+        if hasattr(
+            self.model_runner.model_config.hf_config, "draft_vocab_size"
+        ):  # llama_eagle
+            vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
+        elif hasattr(
+            self.model_runner.model_config.hf_config, "hot_vocab_size"
+        ):  # llama_eagle3
+            vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
+        else:
+            vocab_size = self.model_runner.model_config.vocab_size
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_bs, vocab_size),
+            dtype=torch.float,
+        )
+
         # Capture
         try:
             with model_capture_mode():
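
The draft runner sizes its buffer by the draft head's effective vocabulary: EAGLE heads carry draft_vocab_size on the HF config and EAGLE3 heads carry hot_vocab_size, with the target model's vocab as the fallback. (Unlike the main runner's buffer, this one is allocated without an explicit device argument, so it lands on the current default device.) A hedged restatement of the selection as a standalone helper, with hf_config as a hypothetical config object:

# Hypothetical helper mirroring the branch above; hf_config is any object
# that may expose draft_vocab_size (EAGLE) or hot_vocab_size (EAGLE3).
def effective_draft_vocab_size(hf_config, target_vocab_size: int) -> int:
    if hasattr(hf_config, "draft_vocab_size"):  # llama_eagle
        return hf_config.draft_vocab_size
    if hasattr(hf_config, "hot_vocab_size"):    # llama_eagle3
        return hf_config.hot_vocab_size
    return target_vocab_size                    # dense head: full target vocab
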
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
+        next_token_logits_buffer = self.next_token_logits_buffer[:bs]

         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
...
@@ -564,6 +564,7 @@ class TboForwardBatchPreparer:
                 mm_inputs=None,
                 top_logprobs_nums=None,
                 token_ids_logprobs=None,
+                next_token_logits_buffer=None,
             )
         )
...
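
Note that the two-batch-overlap path opts out: child batches are built with next_token_logits_buffer=None, so their logits take the fallback .float() branch in LogitsProcessor instead of sharing the preallocated buffer.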