Unverified Commit d4c038da authored by Baizhou Zhang, committed by GitHub

[Fix]Fix capture fail bug for DeepSeek (#6275)
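This change replaces the per-object capture flags (MHATokenToKVPool.capture_mode, MllamaForConditionalGeneration.capture_mode) and the torch.cuda.is_current_stream_capturing() check in DeepseekV2AttentionMLA with a single module-level is_capture_mode flag in cuda_graph_runner.py. CudaGraphRunner.model_capture_mode() now toggles that flag, and callers read it through get_is_capture_mode(). Below is a minimal, self-contained sketch of the pattern; ToyKVPool and its method are illustrative stand-ins, not code from this commit:

from contextlib import contextmanager

# Module-level flag, mirroring what this commit adds to cuda_graph_runner.py.
is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    # Toggled for the duration of CUDA graph capture, as in CudaGraphRunner.
    global is_capture_mode
    is_capture_mode = True
    yield
    is_capture_mode = False


class ToyKVPool:
    # Illustrative consumer: reads the shared flag instead of a per-object
    # self.capture_mode attribute or torch.cuda.is_current_stream_capturing().
    def set_kv_buffer(self):
        if get_is_capture_mode():
            return "capture path"  # e.g. overlap K/V copies on an alt stream
        return "normal path"


pool = ToyKVPool()
with model_capture_mode():
    assert pool.set_kv_buffer() == "capture path"
assert pool.set_kv_buffer() == "normal path"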

parent 55f6005f
@@ -266,7 +266,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if is_cuda else None
@@ -385,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -398,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and self.alt_stream is not None:
+        if get_is_capture_mode() and self.alt_stream is not None:
            # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -47,6 +47,13 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+# Detect whether the current forward pass is in capture mode
+is_capture_mode = False
+
+
+def get_is_capture_mode():
+    return is_capture_mode
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -311,17 +318,12 @@ class CudaGraphRunner:
 
     @contextmanager
     def model_capture_mode(self):
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = True
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = True
+        global is_capture_mode
+        is_capture_mode = True
 
         yield
 
-        if hasattr(self.model_runner.model, "capture_mode"):
-            self.model_runner.model.capture_mode = False
-        if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
-            self.model_runner.token_to_kv_pool.capture_mode = False
+        is_capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
@@ -612,6 +614,7 @@ class CudaGraphRunner:
         # Replay
         self.graphs[self.bs].replay()
         output = self.output_buffers[self.bs]
 
         if isinstance(output, LogitsProcessorOutput):
             return LogitsProcessorOutput(
@@ -754,6 +754,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -761,7 +763,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             k_nope = latent_cache[..., : self.kv_lora_rank]
 
             # overlap qk norm
-            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+            if self.alt_stream is not None and get_is_capture_mode():
                 current_stream = torch.cuda.current_stream()
                 self.alt_stream.wait_stream(current_stream)
                 q = self.q_a_layernorm(q)
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -969,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -977,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_mask = None
         cross_attention_states = None
 
-        if self.capture_mode:
+        if get_is_capture_mode():
             # NOTE: when doing cuda graph capture, we do not want to skip cross attention
             # Make is a constant value to avoid cuda graph capture issue
             skip_cross_attention = False