Commit 7c8db5e7 authored by yangql
Browse files

修复get_gcn_arch_name的导入bug

parent c6a45c08
......@@ -98,7 +98,7 @@ def sparse_attn_indexer(
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
-((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or get_gcn_arch_name == "gfx938" else k.dtype,),
+((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or get_gcn_arch_name() == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment