Unverified Commit d4c038da authored by Baizhou Zhang, committed by GitHub

[Fix]Fix capture fail bug for DeepSeek (#6275)

parent 55f6005f
```diff
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if is_cuda else None
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and self.alt_stream is not None:
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
```
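For readers skimming the diff: the branch changed above overlaps the K-cache and V-cache copies on two CUDA streams while the runner is in capture mode. A minimal sketch of that pattern is below, assuming a simplified pool with `k_buffer`/`v_buffer` tensors and a `set_kv_buffer`-style entry point; the exact buffer layout and which copy lands on which stream are assumptions, not the sglang implementation.

```python
import torch


class TinyKVPool:
    """Toy stand-in for MHATokenToKVPool, just to show the stream overlap."""

    def __init__(self, num_slots: int, num_heads: int, head_dim: int, device="cuda"):
        self.k_buffer = torch.empty(num_slots, num_heads, head_dim, device=device)
        self.v_buffer = torch.empty(num_slots, num_heads, head_dim, device=device)
        self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    def set_kv_buffer(self, loc, cache_k, cache_v, is_capture_mode: bool):
        if is_capture_mode and self.alt_stream is not None:
            # The K copy and the V copy are independent, so they can run on
            # two streams; only stream ordering is needed, no extra events.
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            self.k_buffer[loc] = cache_k  # current stream
            with torch.cuda.stream(self.alt_stream):
                self.v_buffer[loc] = cache_v  # alternate stream
            current_stream.wait_stream(self.alt_stream)
        else:
            # Outside capture a plain sequential copy is used.
            self.k_buffer[loc] = cache_k
            self.v_buffer[loc] = cache_v
```

Reading the flag through `get_is_capture_mode()` instead of a `capture_mode` attribute stored on the pool means already-constructed pool objects cannot go stale when the runner toggles capture on and off.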
```diff
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -311,17 +318,12 @@ class CudaGraphRunner:
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -612,6 +614,7 @@ class CudaGraphRunner:
         # Replay
         self.graphs[self.bs].replay()
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
```
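The `cuda_graph_runner` changes replace the per-object `capture_mode` attributes (previously set via `hasattr` checks on the model and the KV pool) with a single module-level flag plus a getter, toggled by the `model_capture_mode` context manager. A stripped-down sketch of the pattern follows; the `capture_graphs` caller is hypothetical and the `try/finally` is an extra safeguard rather than part of the patch:

```python
from contextlib import contextmanager

# Module-level flag: any component can ask "is a CUDA graph being captured?"
# without the runner having to know which objects care about it.
is_capture_mode = False


def get_is_capture_mode() -> bool:
    return is_capture_mode


@contextmanager
def model_capture_mode():
    # Flip the flag for the duration of graph capture, then restore it.
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False


def capture_graphs(capture_one_batch_size, batch_sizes):
    # Hypothetical caller: every forward pass run inside this block observes
    # get_is_capture_mode() == True, including warm-up runs before capture.
    with model_capture_mode():
        for bs in batch_sizes:
            capture_one_batch_size(bs)
```

A process-global flag is sufficient here because graph capture runs on one thread in the model-runner process; importing `get_is_capture_mode` inside the hot paths (as the other hunks do) is presumably there to avoid circular imports between the runner and the model/pool modules.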
```diff
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
```
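Two things are going on in the DeepSeek hunks. First, the capture check switches from `torch.cuda.is_current_stream_capturing()` to `get_is_capture_mode()`: the former is only true while a graph is literally being recorded, whereas the new flag stays set for the whole `model_capture_mode` block (including any warm-up forwards), so the overlapped and non-overlapped paths no longer diverge between warm-up, capture, and replay; that consistency is presumably what fixes the DeepSeek capture failure. Second, the gated branch overlaps the q and k layernorms on two streams, roughly like the sketch below (the function shape, layer names, and stream assignment are illustrative assumptions):

```python
import torch
import torch.nn as nn


def overlapped_qk_norm(q, k_nope, q_norm: nn.Module, k_norm: nn.Module,
                       alt_stream, in_capture: bool):
    # Sketch of the "overlap qk norm" branch in DeepseekV2AttentionMLA.
    if alt_stream is not None and in_capture:
        current_stream = torch.cuda.current_stream()
        alt_stream.wait_stream(current_stream)
        q = q_norm(q)  # runs on the current stream
        with torch.cuda.stream(alt_stream):
            k_nope = k_norm(k_nope)  # runs concurrently on the alternate stream
        current_stream.wait_stream(alt_stream)
    else:
        q = q_norm(q)
        k_nope = k_norm(k_nope)
    return q, k_nope
```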
```diff
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None
 
-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make is a constant value to avoid cuda graph capture issue
             skip_cross_attention = False
```
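The Mllama hunks apply the same switch: the per-model `capture_mode` attribute goes away and the forward pass asks the runner directly. The NOTE in the diff is the key constraint: a captured CUDA graph freezes control flow, so a data-dependent skip of cross attention has to be pinned to one value while recording. A tiny illustration, where the skip condition itself is an assumption rather than the real Mllama logic:

```python
def decide_skip_cross_attention(encoder_lens, in_capture: bool) -> bool:
    # Illustrative only; the real skip condition in Mllama may differ.
    if in_capture:
        # Pin the branch during capture: always run cross attention so the
        # recorded graph contains those kernels for every replayed batch.
        return False
    # Outside capture the decision can depend on the actual batch contents.
    return encoder_lens is None or bool((encoder_lens == 0).all())
```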