TORCH_CHECK(kcache.dtype()!=q_dtype,"非量化情况下, query and key must have not the same dtype");
CHECK_DEVICE(k_scale);
TORCH_CHECK(k_scale.dtype()==torch::kFloat32,"非量化情况下, query and key must have the same dtype");
TORCH_CHECK(is_gfx936,"fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures");
// TORCH_CHECK(is_gfx936, "fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures");
}
else
{
...
...
@@ -334,7 +336,10 @@ mha_fwd_kvcache_mla_nope_pe(
// auto dprops = at::cuda::getCurrentDeviceProperties();