Unverified Commit 61a97c32 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Kernel] Fix marlin divide-by-zero warnings (#6904)

parent 4fbf4aa1
......@@ -1128,9 +1128,11 @@ __global__ void Marlin(
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
if constexpr (!has_zp) {
return;
}
if constexpr (has_zp) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(group_blocks != 0);
int pipe = full_pipe % stages;
......@@ -1168,6 +1170,7 @@ __global__ void Marlin(
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
......
......@@ -452,12 +452,17 @@ __global__ void Marlin(
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
......@@ -480,7 +485,10 @@ __global__ void Marlin(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
......
......@@ -404,12 +404,17 @@ __global__ void Marlin_24(
meta_ptr[i] += m_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
......@@ -432,7 +437,10 @@ __global__ void Marlin_24(
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
if constexpr (group_blocks != -1) {
// This assumes group_blocks >= thread_k_blocks
// and would need to be modified to support smaller groups.
static_assert(group_blocks >= thread_k_blocks);
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment