Commit 6a8fd297 authored by zhuwenwen's avatar zhuwenwen
Browse files

set VLLM_USE_LIGHTOP=0 for dpsk-v3

add VLLM_USE_PD_SPLIT to split prefill and decode
replace triton_ of rms and act_and_mul
parent 31201280
......@@ -234,6 +234,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False
def get_default_cache_root():
......@@ -1632,6 +1633,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -80,6 +80,9 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x)
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
......
......@@ -191,6 +191,9 @@ class RMSNorm(CustomOp):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual)
else:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......
......@@ -194,8 +194,8 @@ def _get_model_architecture(
os.environ['LM_NN'] = '1'
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP"):
# os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
......
......@@ -558,6 +558,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device,
)
self.block_table = block_table
self.use_spec_decode = False
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
......@@ -651,13 +659,31 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len <= self.reorder_batch_threshold # decode only
# assert m.num_reqs <= (m.num_actual_tokens *
# self.reorder_batch_threshold), \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.max_query_len <= self.reorder_batch_threshold # decode only
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m)
def build(self,
......@@ -673,8 +699,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......@@ -840,8 +873,57 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
- prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
decode_metadata = None
if num_decodes > 0:
# if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
# cu_num_blocks = np.cumsum(query_lens)
# virtual_batches = cu_num_blocks[-1]
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# rarange = np.repeat(query_lens, query_lens) - arange - 1
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
......
......@@ -9,6 +9,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
......@@ -1045,7 +1046,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output
def schedule(self) -> SchedulerOutput:
if self.num_spec_tokens > 0:
if self.num_spec_tokens > 0 or envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd()
else:
return self.schedule_default()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment