Commit 5fe03549 authored by zhuwenwen's avatar zhuwenwen
Browse files

[perf] add VLLM_USE_FUSED_FILL_RMS_CAT to use lightop for dpsk mtp fill + rms*2 + cat

parent b8c7ba0a
...@@ -288,6 +288,7 @@ if TYPE_CHECKING: ...@@ -288,6 +288,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
...@@ -1819,6 +1820,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1819,6 +1820,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")),
# W8A8 GEMM backend selection for vLLM quantized models. # W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1 # lightop/triton: 1
# cutlass: 2 (will remove in the future) # cutlass: 2 (will remove in the future)
......
...@@ -192,6 +192,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -192,6 +192,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: # if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if not envs.is_set("USE_FUSED_RMS_QUANT"): # if not envs.is_set("USE_FUSED_RMS_QUANT"):
# os.environ['USE_FUSED_RMS_QUANT'] = '1' # os.environ['USE_FUSED_RMS_QUANT'] = '1'
...@@ -224,6 +226,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -224,6 +226,8 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_FUSED_FILL_RMS_CAT"):
os.environ['VLLM_USE_FUSED_FILL_RMS_CAT'] = '1'
# if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}: # if model_config.quantization in {"slimquant_w4a8", "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin", "compressed-tensors"}:
# if not envs.is_set("USE_FUSED_RMS_QUANT"): # if not envs.is_set("USE_FUSED_RMS_QUANT"):
# os.environ['USE_FUSED_RMS_QUANT'] = '1' # os.environ['USE_FUSED_RMS_QUANT'] = '1'
......
...@@ -39,6 +39,7 @@ from .deepseek_v2 import ( ...@@ -39,6 +39,7 @@ from .deepseek_v2 import (
from .utils import maybe_prefix from .utils import maybe_prefix
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -109,6 +110,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -109,6 +110,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
assert inputs_embeds is not None assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP # masking inputs at position 0, as not needed by MTP
if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
hidden_states_fuse = torch.empty(inputs_embeds.shape[0], inputs_embeds.shape[1]*2, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
torch.ops.vllm.fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, self.enorm.weight, self.hnorm.weight, self.enorm.variance_epsilon)
hidden_states = self.eh_proj(hidden_states_fuse)
else:
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds) inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states) previous_hidden_states = self.hnorm(previous_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment