Unverified Commit f83de719 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Fix OOB read in CUTLASS grouped GEMM with epilogue (#38571)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 445a2a4d
......@@ -389,20 +389,28 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col_array[group]));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
// tCgCol has layout (CPY,CPY_M,CPY_N,EPI_M,EPI_N) where CPY_N and
// EPI_N are stride-0 for the column broadcast. Slice those modes at
// index 0 to avoid redundant copies AND ensure pred/data consistency
static_assert(decltype(stride<2>(tCgCol))::value == 0, "Expected stride-0 CPY_N for col broadcast");
static_assert(decltype(stride<4>(tCgCol))::value == 0, "Expected stride-0 EPI_N for col broadcast");
auto tCgCol_s = tCgCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
auto tCrCol_s = tCrCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
auto tCcCol_s = tCcCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
cute::Tensor pred = make_tensor<bool>(shape(tCgCol_s));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol_s(i)) < m;
}
copy_if(pred, tCgCol_s, tCrCol_s);
}
template <typename ElementAccumulator, int FragmentSize>
......
......@@ -382,20 +382,28 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
// tCgCol has layout (CPY,CPY_M,CPY_N,EPI_M,EPI_N) where CPY_N and
// EPI_N are stride-0 for the column broadcast. Slice those modes at
// index 0 to avoid redundant copies AND ensure pred/data consistency
static_assert(decltype(stride<2>(tCgCol))::value == 0, "Expected stride-0 CPY_N for col broadcast");
static_assert(decltype(stride<4>(tCgCol))::value == 0, "Expected stride-0 EPI_N for col broadcast");
auto tCgCol_s = tCgCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
auto tCrCol_s = tCrCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
auto tCcCol_s = tCcCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M)
cute::Tensor pred = make_tensor<bool>(shape(tCgCol_s));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol_s(i)) < m;
}
copy_if(pred, tCgCol_s, tCrCol_s);
}
template <typename ElementAccumulator, int FragmentSize>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment