// SPDX-License-Identifier: MIT #include #include #include #include #include #include #include "ck_grouped_gemm_abi.h" #include "grouped_gemm_ck.h" namespace { int dtype_to_grouped_gemm_dtype(const at::ScalarType dtype) { switch(dtype) { case at::ScalarType::Half: return CK_TILE_DCU_GROUPED_GEMM_FP16; case at::ScalarType::BFloat16: return CK_TILE_DCU_GROUPED_GEMM_BF16; case at::ScalarType::Float8_e4m3fn: return CK_TILE_DCU_GROUPED_GEMM_FP8; case at::ScalarType::Char: return CK_TILE_DCU_GROUPED_GEMM_INT8; default: TORCH_CHECK(false, "ck_grouped_gemm: unsupported dtype: ", dtype); } } at::ScalarType output_dtype(const at::ScalarType dtype) { if(dtype == at::ScalarType::Char) { return at::ScalarType::Int; } if(dtype == at::ScalarType::Float8_e4m3fn) { return at::ScalarType::Float; } return dtype; } void check_grouped_gemm_tensor(const torch::Tensor& t, const char* name) { TORCH_CHECK(t.is_cuda(), "ck_grouped_gemm: ", name, " must be a CUDA tensor"); TORCH_CHECK(t.dim() == 2, "ck_grouped_gemm: ", name, " tensors must be 2D"); TORCH_CHECK(t.is_contiguous(), "ck_grouped_gemm: ", name, " tensors must be contiguous"); } void check_supported_shape(const at::ScalarType dtype, int64_t m, int64_t n, int64_t k) { if(dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) { TORCH_CHECK(n % 128 == 0 && k % 64 == 0 && k >= 128, "ck_grouped_gemm: fp16/bf16 requires N % 128 == 0, K % 64 == 0, K >= 128"); } else if(dtype == at::ScalarType::Float8_e4m3fn) { TORCH_CHECK(n % 128 == 0 && k % 128 == 0, "ck_grouped_gemm: fp8 requires N % 128 == 0, K % 128 == 0"); } else if(dtype == at::ScalarType::Char) { TORCH_CHECK(m % 32 == 0 && n % 32 == 0 && k % 128 == 0, "ck_grouped_gemm: int8 requires M % 32 == 0, N % 32 == 0, K % 128 == 0"); } } torch::Tensor grouped_gemm_workspace(const at::Device& device, int64_t nbytes) { static at::Device cached_device = at::Device(at::DeviceType::CUDA, -1); static int64_t cached_nbytes = 0; static torch::Tensor cached_buffer; if(cached_device != device || cached_nbytes < nbytes) { cached_device = device; cached_nbytes = nbytes; cached_buffer = torch::empty({nbytes}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); } return cached_buffer; } } // namespace std::vector ck_grouped_gemm_impl(std::vector& a_tensors, std::vector& b_tensors, std::vector* c_tensors_out) { TORCH_CHECK(!a_tensors.empty(), "ck_grouped_gemm: a tensor list must not be empty"); TORCH_CHECK(a_tensors.size() == b_tensors.size(), "ck_grouped_gemm: a and b tensor lists must have the same length"); if(c_tensors_out != nullptr) { TORCH_CHECK(c_tensors_out->size() == a_tensors.size(), "ck_grouped_gemm: c tensor list must match a/b length"); } TORCH_CHECK(a_tensors.size() <= static_cast(std::numeric_limits::max()), "ck_grouped_gemm: group count exceeds int range expected by CK C ABI"); const auto dtype = a_tensors[0].scalar_type(); const auto device = a_tensors[0].device(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(a_tensors[0])); const auto c_dtype = output_dtype(dtype); std::vector outputs; outputs.reserve(a_tensors.size()); std::vector descs; descs.reserve(a_tensors.size()); for(std::size_t i = 0; i < a_tensors.size(); ++i) { auto& a = a_tensors[i]; auto& b = b_tensors[i]; check_grouped_gemm_tensor(a, "a"); check_grouped_gemm_tensor(b, "b"); TORCH_CHECK(a.device() == device && b.device() == device, "ck_grouped_gemm: all tensors must be on the same device"); TORCH_CHECK(a.scalar_type() == dtype && b.scalar_type() == dtype, "ck_grouped_gemm: all a/b tensors must have the same dtype"); const int64_t m = a.size(0); const int64_t k = a.size(1); const int64_t n = b.size(0); TORCH_CHECK(b.size(1) == k, "ck_grouped_gemm: K mismatch at group ", i); TORCH_CHECK(m > 0 && n > 0 && k > 0, "ck_grouped_gemm: all dimensions must be positive"); check_supported_shape(dtype, m, n, k); TORCH_CHECK(m <= std::numeric_limits::max() && n <= std::numeric_limits::max() && k <= std::numeric_limits::max(), "ck_grouped_gemm: dimensions exceed int range expected by CK C ABI"); torch::Tensor c; if(c_tensors_out != nullptr) { c = c_tensors_out->at(i); check_grouped_gemm_tensor(c, "c"); TORCH_CHECK(c.device() == device, "ck_grouped_gemm: all c tensors must be on the same device"); TORCH_CHECK(c.scalar_type() == c_dtype, "ck_grouped_gemm: c tensor dtype mismatch at group ", i); TORCH_CHECK(c.size(0) == m && c.size(1) == n, "ck_grouped_gemm: c tensor shape mismatch at group ", i, ", expected [", m, ", ", n, "]"); } else { c = torch::empty({m, n}, torch::TensorOptions().dtype(c_dtype).device(device)); } outputs.push_back(c); descs.push_back(ck_tile_dcu_grouped_gemm_desc{a.data_ptr(), b.data_ptr(), c.data_ptr(), 1, static_cast(m), static_cast(n), static_cast(k), static_cast(k), static_cast(k), static_cast(n), 0, nullptr, nullptr}); } const auto workspace_bytes = ck_tile_dcu_grouped_gemm_workspace_size(static_cast(descs.size()), 0); auto workspace = grouped_gemm_workspace(device, static_cast(workspace_bytes)); const hipStream_t stream = at::hip::getCurrentHIPStream(); const int rc = ck_tile_dcu_grouped_gemm_run(descs.data(), static_cast(descs.size()), dtype_to_grouped_gemm_dtype(dtype), 'R', 'C', workspace.data_ptr(), stream); TORCH_CHECK(rc == 0, "ck_grouped_gemm: CK C ABI returned error ", rc); return outputs; } std::vector ck_grouped_gemm(std::vector& a_tensors, std::vector& b_tensors) { return ck_grouped_gemm_impl(a_tensors, b_tensors, nullptr); } std::vector ck_grouped_gemm_out(std::vector& a_tensors, std::vector& b_tensors, std::vector& c_tensors) { return ck_grouped_gemm_impl(a_tensors, b_tensors, &c_tensors); }