Commit 6e7c8326 authored by liuchy5
Browse files

feat:fix dsa

parent db85ab07
......@@ -184,7 +184,7 @@ class CustomOp(nn.Module):
return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm():
return self.forward_hip
return self.forward_cuda
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
......
......@@ -88,13 +88,13 @@ def sparse_attn_indexer(
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment