Unverified Commit ce9b1d76 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[MRV2] Skip hidden states allocation for PW CUDA graphs (#37818)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent e74c17e1
...@@ -263,6 +263,7 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -263,6 +263,7 @@ class ModelCudaGraphManager(CudaGraphManager):
decode_query_len: int, decode_query_len: int,
): ):
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len) super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
# Used for FULL CUDA graphs. PW CUDA graphs do not use these.
self.hidden_states: torch.Tensor | None = None self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = [] self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
...@@ -326,6 +327,12 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -326,6 +327,12 @@ class ModelCudaGraphManager(CudaGraphManager):
**model_state.prepare_dummy_inputs(num_reqs, num_tokens), **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
} }
model_output = model(**model_inputs) model_output = model(**model_inputs)
if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states.
return None
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment