Commit 0786df31 authored by zhangqha
Browse files

Merge branch 'v0.15.1-dev_yql_3.5' into 'v0.15.1-dev'

修复 DSA 的 workspace 的 bug,以及添加环境变量 VLLM_DISABLE_DSA=1 用于关闭 DSA

See merge request dcutoolkit/deeplearing/vllm!463
parents 4661cd18 cb1a27d2
...@@ -49,7 +49,7 @@ def sparse_attn_indexer( ...@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run # Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous( current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn), ((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16),
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, 4), torch.uint8),
) )
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
......
...@@ -842,8 +842,9 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -842,8 +842,9 @@ class DeepseekV2MLAAttention(nn.Module):
scaling_factor = config.rope_parameters["factor"] scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
#添加判断,默认开启DSA
self.is_v32 = hasattr(config, "index_topk") force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32: if self.is_v32:
self.indexer_rope_emb = get_rope( self.indexer_rope_emb = get_rope(
...@@ -1061,7 +1062,10 @@ class DeepseekV2Model(nn.Module): ...@@ -1061,7 +1062,10 @@ class DeepseekV2Model(nn.Module):
self.device = current_platform.device_type self.device = current_platform.device_type
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk") #添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
self.is_v32 = hasattr(config, "index_topk") and not force_disable_dsa
if self.is_v32: if self.is_v32:
topk_tokens = config.index_topk topk_tokens = config.index_topk
topk_indices_buffer = torch.empty( topk_indices_buffer = torch.empty(
...@@ -1355,10 +1359,18 @@ class DeepseekV2ForCausalLM( ...@@ -1355,10 +1359,18 @@ class DeepseekV2ForCausalLM(
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
# 判断是否加载"indexer"权重
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
# 跳过加载"indexer"权重
if "indexer" in name and not model_has_indexer:
logger.info(f"Skipping indexer weight (DSA disabled): {name}")
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment