Commit 6a8fd297 authored by zhuwenwen's avatar zhuwenwen
Browse files

set VLLM_USE_LIGHTOP=0 for dpsk-v3

add VLLM_USE_PD_SPLIT to split prefill and decode
replace triton_ of rms and act_and_mul
parent 31201280
...@@ -234,6 +234,7 @@ if TYPE_CHECKING: ...@@ -234,6 +234,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1632,6 +1633,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1632,6 +1633,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -80,6 +80,9 @@ class SiluAndMul(CustomOp): ...@@ -80,6 +80,9 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x)
else:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:] return F.silu(x[..., :d]) * x[..., d:]
......
...@@ -191,6 +191,9 @@ class RMSNorm(CustomOp): ...@@ -191,6 +191,9 @@ class RMSNorm(CustomOp):
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual)
else:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
......
...@@ -194,8 +194,8 @@ def _get_model_architecture( ...@@ -194,8 +194,8 @@ def _get_model_architecture(
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): # if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' # os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
......
...@@ -558,6 +558,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -558,6 +558,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device, device=device,
) )
self.block_table = block_table
self.use_spec_decode = False
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc qo_indptr = prefill.query_start_loc
...@@ -651,13 +659,31 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -651,13 +659,31 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA. Currently, only decode is supported for full cudagraphs with MLA.
""" """
m = common_attn_metadata m = common_attn_metadata
assert m.num_reqs <= (m.num_actual_tokens * # assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \ # self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \ # "MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq." # "Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len <= self.reorder_batch_threshold # decode only # assert m.max_query_len <= self.reorder_batch_threshold # decode only
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
return self.build(0, m) return self.build(0, m)
def build(self, def build(self,
...@@ -673,8 +699,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -673,8 +699,15 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because # function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels. # it blocks on all previous kernels.
device = self.device device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
...@@ -840,8 +873,57 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -840,8 +873,57 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
- prefill_query_start_loc[:-1] - prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
decode_metadata = None decode_metadata = None
if num_decodes > 0: # if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
# cu_num_blocks = np.cumsum(query_lens)
# virtual_batches = cu_num_blocks[-1]
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# rarange = np.repeat(query_lens, query_lens) - arange - 1
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata = self._build_decode( decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...], block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes], seq_lens_cpu=seq_lens_cpu[:num_decodes],
......
...@@ -9,6 +9,7 @@ from collections import defaultdict ...@@ -9,6 +9,7 @@ from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Optional, Union from typing import Any, Optional, Union
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
...@@ -1045,7 +1046,7 @@ class Scheduler(SchedulerInterface): ...@@ -1045,7 +1046,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output return scheduler_output
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if self.num_spec_tokens > 0: if self.num_spec_tokens > 0 or envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd() return self.schedule_split_pd()
else: else:
return self.schedule_default() return self.schedule_default()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment