"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "3d330c4c095b78b3e6226d99f4d4a7a0965f3758"
Unverified Commit 202f147c authored by ElizaWszola, committed by GitHub
Browse files

Fix MLA runs when use_inductor_graph_partition=True (#38631)


Signed-off-by: ElizaWszola <ewszola@redhat.com>
parent ea7bfde6
...@@ -929,13 +929,15 @@ def unified_mla_kv_cache_update( ...@@ -929,13 +929,15 @@ def unified_mla_kv_cache_update(
the data dependency between them to ensure torch.compile preserves ordering. the data dependency between them to ensure torch.compile preserves ordering.
""" """
forward_context = get_forward_context() forward_context = get_forward_context()
if forward_context.attn_metadata is None:
# Dummy/profile forwards should not update live KV cache pages.
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
attn_layer = forward_context.no_compile_layers[layer_name] attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache kv_cache = attn_layer.kv_cache
# This needs to run even when we don't have metadata yet, so that the op
# is correctly captured.
if kv_cache.numel() == 0:
# Can't update an empty KV cache.
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
slot_mapping = forward_context.slot_mapping slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), ( assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment