Commit 451af742 authored by zhuwenwen
Browse files

fix weights_not_loaded

update weights_not_loaded and flash_mla_with_kvcache
update paged_mqa_logits
parent aa05dfd5
......@@ -313,7 +313,7 @@ class DefaultModelLoader(BaseModelLoader):
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
weights_not_loaded = {k for k in weights_not_loaded if not k.endswith("indexer.weights_proj.bias")}
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
......
......@@ -336,6 +336,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
is_fp8_kvcache=False,
indices= None,
)
o = reshape_attn_output_for_spec_decode(o)
......
......@@ -315,7 +315,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
if current_platform.is_rocm():
self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata(
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment