Unverified Commit c12df53b authored by TherLF's avatar TherLF Committed by GitHub
Browse files

[Bugfix] Fix cutlass dispatch for fp8/int8 to properly invoke M<=16 c… (#16751)


Signed-off-by: default avatarTher-LF <2639852836@qq.com>
parent d1aeea75
...@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out, ...@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) { if (mp2 <= 16) {
// M in [1, 16] // M in [1, 16]
......
...@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, ...@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) { if (mp2 <= 16) {
// M in [1, 16] // M in [1, 16]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment