Commit 451af742 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix weights_not_loaded

update weights_not_loaded and flash_mla_with_kvcache
update paged_mqa_logits
parent aa05dfd5
...@@ -313,9 +313,9 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -313,9 +313,9 @@ class DefaultModelLoader(BaseModelLoader):
# We only enable strict check for non-quantized models # We only enable strict check for non-quantized models
# that have loaded weights tracking currently. # that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None: if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights weights_not_loaded = {k for k in weights_not_loaded if not k.endswith("indexer.weights_proj.bias")}
if weights_not_loaded: if weights_not_loaded:
raise ValueError( raise ValueError(
"Following weights were not initialized from " "Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}" f"checkpoint: {weights_not_loaded}"
) )
\ No newline at end of file
...@@ -336,6 +336,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -336,6 +336,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True, causal=True,
descale_q=layer._q_scale.reshape(1), descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1), descale_k=layer._k_scale.reshape(1),
is_fp8_kvcache=False,
indices= None,
) )
o = reshape_attn_output_for_spec_decode(o) o = reshape_attn_output_for_spec_decode(o)
......
...@@ -315,7 +315,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -315,7 +315,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported(): if is_deep_gemm_supported():
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer[:] = gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
) )
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment