#include #include namespace spconv { template using determine_half_t = std::conditional_t::value, cutlass::half_t, T>; void cutlass_mm_out(cudaStream_t stream, torch::Tensor c, torch::Tensor a, torch::Tensor b) { TV_ASSERT_RT_ERR(c.dtype() == a.dtype() && c.dtype() == b.dtype(), "dtype must be same"); TV_ASSERT_RT_ERR(c.is_contiguous() && b.is_contiguous() && a.is_contiguous(), "error"); auto M = a.size(0); auto K = a.size(1); auto N = b.size(1); TV_ASSERT_RT_ERR(b.size(0) == K && c.size(0) == M && c.size(1) == N, "error"); tv::dispatch_torch(c.scalar_type(), [&](auto I) { using T = decltype(I); using HalfT = determine_half_t; auto status = cutlassGemm( stream, M, N, K, HalfT(1.0), reinterpret_cast(a.data_ptr()), a.size(1), reinterpret_cast(b.data_ptr()), b.size(1), HalfT(0.0), reinterpret_cast(c.data_ptr()), c.size(1)); TV_ASSERT_RT_ERR(status == cudaSuccess, "error"); }); } void cutlass_mm_out(torch::Tensor c, torch::Tensor a, torch::Tensor b) { auto stream = at::cuda::getCurrentCUDAStream(); return cutlass_mm_out(stream, c, a, b); } } // namespace spconv