Commit cb1a27d2 authored by yangql's avatar yangql
Browse files

修复dsa的workspace的bug,以及添加环境变量关闭DSAVLLM_DISABLE_DSA=1

parent 4661cd18
......@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
......
......@@ -842,8 +842,9 @@ class DeepseekV2MLAAttention(nn.Module):
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.is_v32 = hasattr(config, "index_topk")
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32:
self.indexer_rope_emb = get_rope(
......@@ -1061,7 +1062,10 @@ class DeepseekV2Model(nn.Module):
self.device = current_platform.device_type
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
......@@ -1355,10 +1359,18 @@ class DeepseekV2ForCausalLM(
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
# 判断是否加载"indexer"权重
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# 跳过加载"indexer"权重
if "indexer" in name and not model_has_indexer:
logger.info(f"Skipping indexer weight (DSA disabled): {name}")
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment