Commit 84f83034 authored by zhanghj2
Browse files

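Fix h_q < 16 bug: h_q/16 (and h_q/64) truncate to 0 under integer division when h_q is smaller than the divisor; use cutlass::ceil_div instead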

parent 40f4bf39
#pragma once #pragma once
#include "common.h" #include "common.h"
#include <cutlass/fast_math.h>
#include "params.h" #include "params.h"
...@@ -60,13 +61,13 @@ public: ...@@ -60,13 +61,13 @@ public:
Arch arch = Arch(); Arch arch = Arch();
if (h_q <= 16) { if (h_q <= 16) {
return { return {
std::max(arch.num_sms * 2 / s_q / (h_q/16), 1), std::max(arch.num_sms * 2 / s_q / cutlass::ceil_div(h_q, 16), 1),
5, 5,
64 64
}; };
} }
return { return {
std::max(arch.num_sms / s_q / (h_q/64), 1), std::max(arch.num_sms / s_q / cutlass::ceil_div(h_q, 64), 1),
5, 5,
64 64
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment