Commit 6e7c8326 authored by liuchy5's avatar liuchy5
Browse files

feat:fix dsa

parent db85ab07
...@@ -184,7 +184,7 @@ class CustomOp(nn.Module): ...@@ -184,7 +184,7 @@ class CustomOp(nn.Module):
return self.maybe_compile(self.forward_native, enable=compile_native) return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm(): if current_platform.is_rocm():
return self.forward_hip return self.forward_cuda
elif current_platform.is_cpu(): elif current_platform.is_cpu():
return self.forward_cpu return self.forward_cpu
elif current_platform.is_tpu(): elif current_platform.is_tpu():
......
...@@ -88,13 +88,13 @@ def sparse_attn_indexer( ...@@ -88,13 +88,13 @@ def sparse_attn_indexer(
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use) # Get the full shared workspace buffers once (will allocate on first use)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
workspace_manager = current_workspace_manager() workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype), ((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8), ((total_seq_lens, 4), torch.uint8),
) )
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache( ops.cp_gather_indexer_k_quant_cache(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment