Commit df845734 authored by zhuwenwen's avatar zhuwenwen
Browse files

update VLLM_USE_FUSE_SILU_AND_MUL

parents d261a1e6 a921f34c
...@@ -186,6 +186,7 @@ if TYPE_CHECKING: ...@@ -186,6 +186,7 @@ if TYPE_CHECKING:
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_CAT_MLA: bool = False VLLM_USE_CAT_MLA: bool = False
VLLM_REJECT_SAMPLE_OPT: bool = False VLLM_REJECT_SAMPLE_OPT: bool = False
# Static (type-checking) default; matches the "False" fallback used by the runtime lookup in environment_variables.
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1206,6 +1207,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1206,6 +1207,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_REJECT_SAMPLE_OPT": "VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in
("true", "1")), ("true", "1")),
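# If set to "true" or "1", vLLM uses the fused silu+mul kernel (lightop's fuse_silu_and_mul) for the "silu" activation in fused_experts_impl. Disabled by default.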
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1875,6 +1875,10 @@ def fused_experts_impl( ...@@ -1875,6 +1875,10 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if activation == "silu": if activation == "silu":
if envs.VLLM_USE_FUSE_SILU_AND_MUL:
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
elif activation == "gelu": elif activation == "gelu":
......
...@@ -258,7 +258,8 @@ def get_model_architecture( ...@@ -258,7 +258,8 @@ def get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
...@@ -279,6 +280,8 @@ def get_model_architecture( ...@@ -279,6 +280,8 @@ def get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
# AWQ-related configuration # AWQ-related configuration
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment