Unverified Commit 4c22ebe2 authored by hlu1's avatar hlu1 Committed by GitHub
Browse files

Disable kernel cutlass_mla_decode on SM103 (#10058)


Signed-off-by: default avatarHao Lu <14827759+hlu1@users.noreply.github.com>
parent a5a03209
...@@ -26,6 +26,7 @@ limitations under the License. ...@@ -26,6 +26,7 @@ limitations under the License.
#include "cutlass_sm100_mla/device/sm100_mla.hpp" #include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
#include "utils.h"
// clang-format off // clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
...@@ -217,6 +218,10 @@ void cutlass_mla_decode( ...@@ -217,6 +218,10 @@ void cutlass_mla_decode(
torch::Tensor const& workspace, torch::Tensor const& workspace,
double sm_scale, double sm_scale,
int64_t num_kv_splits) { int64_t num_kv_splits) {
auto sm_version = getSMVersion();
// On SM103a, half of the accuracy tests are failing.
TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found sm version ", sm_version);
auto in_dtype = q_nope.dtype(); auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
......
...@@ -4,9 +4,10 @@ import torch.nn.functional as F ...@@ -4,9 +4,10 @@ import torch.nn.functional as F
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor from torch import Tensor
if torch.cuda.get_device_capability() < (10, 0): # Disable tests on SM103 until the accuracy issues are fixed.
if torch.cuda.get_device_capability() != (10, 0):
pytest.skip( pytest.skip(
reason="Cutlass MLA Requires compute capability of 10 or above.", reason="Cutlass MLA Requires compute capability of 10.",
allow_module_level=True, allow_module_level=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment