Commit 2888b4e5 authored by yangql's avatar yangql Committed by zhangzbb
Browse files

优化VLLM_DISABLE_DSA的设置,加入envs中,默认关,开启可强制关闭dsa

parent 9eff7ac1
......@@ -323,6 +323,7 @@ if TYPE_CHECKING:
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False
USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -2002,7 +2003,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX":
lambda: (os.environ.get("USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX", "False").lower() in
("true", "1")),
#If set to 1/True, disenable the DSA.
"VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -3,6 +3,7 @@
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
......@@ -554,7 +555,8 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
cache_config = vllm_config.cache_config
if cache_config.cache_dtype.startswith("fp8"):
force_disable_dsa = envs.VLLM_DISABLE_DSA
if cache_config.cache_dtype.startswith("fp8") and not force_disable_dsa:
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
......
......@@ -80,7 +80,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.device = current_platform.device_type
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32:
topk_tokens = config.index_topk
......
......@@ -901,7 +901,7 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32:
......@@ -1219,7 +1219,7 @@ class DeepseekV2Model(nn.Module):
self.vocab_size = config.vocab_size
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
force_disable_dsa = envs.VLLM_DISABLE_DSA
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment