Commit 81175fe4 authored by 王敏
Browse files

[fix]解决v1 deepseek cudagraph模式显存占用增长

parent 751592a6
...@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator ...@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.forward_context import get_profilling
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -169,7 +170,7 @@ def _support_torch_compile( ...@@ -169,7 +170,7 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling(): if self.do_not_compile or torch.compiler.is_compiling() or get_profilling():
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
# the first compilation needs to have dynamic shapes marked # the first compilation needs to have dynamic shapes marked
......
...@@ -196,3 +196,16 @@ def set_forward_context( ...@@ -196,3 +196,16 @@ def set_forward_context(
_forward_context = prev_context _forward_context = prev_context
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context) set_tbo_forward_context(_forward_context)
# Module-level flag that is True while a profiling run is in progress.
# Readers query it through get_profilling(); writers use set_profilling().
_profiling: bool = False


class _ProfilingScope:
    """Handle returned by set_profilling().

    Discarding the handle leaves the newly set flag in place. Using it in a
    ``with`` statement restores the flag value that was active before the
    corresponding set_profilling() call once the block exits, including when
    the block raises.
    """

    __slots__ = ("_previous", )

    def __init__(self, previous: bool) -> None:
        self._previous = previous

    def __enter__(self) -> "_ProfilingScope":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        global _profiling
        _profiling = self._previous


def set_profilling(profiling: bool) -> _ProfilingScope:
    """Set the module-level profiling flag.

    The flag is updated immediately, so a plain ``set_profilling(True)`` /
    ``set_profilling(False)`` pair works as a simple setter. The returned
    handle may additionally be used as a context manager to restore the
    previous value automatically::

        with set_profilling(True):
            ...  # flag is True here
        # flag is back to its earlier value here

    Args:
        profiling: New value of the profiling flag.

    Returns:
        A handle that restores the previous flag value on ``__exit__``.
    """
    # NOTE: this used to be decorated with @contextmanager although it never
    # yielded; the flag was only set as a side effect of building the wrapper
    # object, and entering the result in a ``with`` statement failed.
    global _profiling
    previous = _profiling
    _profiling = profiling
    return _ProfilingScope(previous)


def get_profilling() -> bool:
    """Return the current value of the module-level profiling flag."""
    return _profiling
\ No newline at end of file
...@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import ( ...@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank, get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import (DPMetadata, get_forward_context, from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context) set_forward_context, set_profilling)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -2087,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2087,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.speculative_config and self.speculative_config.use_eagle(): if self.speculative_config and self.speculative_config.use_eagle() and not is_profile:
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens, attn_metadata) self.drafter.dummy_run(num_tokens, attn_metadata)
...@@ -2222,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2222,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return pooler_output return pooler_output
def profile_run(self) -> None: def profile_run(self) -> None:
# set profiling flag to avoid torch compile
set_profilling(True)
self._sync_device()
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them. # TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
...@@ -2305,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2305,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
del hidden_states, output del hidden_states, output
self.encoder_cache.clear() self.encoder_cache.clear()
gc.collect() gc.collect()
set_profilling(False)
def capture_model(self) -> None: def capture_model(self) -> None:
if not self.use_cuda_graph: if not self.use_cuda_graph:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment