Unverified Commit 3f710930 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Update default stride (#1576)

* Update default stride value to -1

* Fix format

* Revert "Fix format"

This reverts commit ae0c3649

.

---------
Co-authored-by: default avatarHarisankar Sadasivan <135730918+hsadasiv@users.noreply.github.com>
parent 794f2d64
...@@ -29,9 +29,9 @@ struct ProblemSize final ...@@ -29,9 +29,9 @@ struct ProblemSize final
ck::index_t N = 4096; ck::index_t N = 4096;
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = 0; ck::index_t StrideA = -1;
ck::index_t StrideB = 0; ck::index_t StrideB = -1;
ck::index_t StrideC = 0; ck::index_t StrideC = -1;
}; };
struct ProblemSizeStreamK final struct ProblemSizeStreamK final
...@@ -40,9 +40,9 @@ struct ProblemSizeStreamK final ...@@ -40,9 +40,9 @@ struct ProblemSizeStreamK final
ck::index_t N = 4096; ck::index_t N = 4096;
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = 0; ck::index_t StrideA = -1;
ck::index_t StrideB = 0; ck::index_t StrideB = -1;
ck::index_t StrideC = 0; ck::index_t StrideC = -1;
ck::index_t NumSKBlocks = -1; ck::index_t NumSKBlocks = -1;
}; };
...@@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final ...@@ -52,9 +52,9 @@ struct ProblemSizeStreamK_universal final
ck::index_t N = 4096; ck::index_t N = 4096;
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = 0; ck::index_t StrideA = -1;
ck::index_t StrideB = 0; ck::index_t StrideB = -1;
ck::index_t StrideC = 0; ck::index_t StrideC = -1;
ck::index_t Grid_size = -1; // defaults to max occupancy ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
...@@ -66,9 +66,9 @@ struct ProblemSizeSplitK final ...@@ -66,9 +66,9 @@ struct ProblemSizeSplitK final
ck::index_t N = 4096; ck::index_t N = 4096;
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = 0; ck::index_t StrideA = -1;
ck::index_t StrideB = 0; ck::index_t StrideB = -1;
ck::index_t StrideC = 0; ck::index_t StrideC = -1;
ck::index_t KBatch = 1; ck::index_t KBatch = 1;
}; };
......
...@@ -116,21 +116,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -116,21 +116,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}; };
auto f_get_default_stride = auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == 0) if(stride == -1)
{ {
// give a chance if stride is zero, return a default packed stride // give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return col; return static_cast<std::size_t>(col);
} }
else else
{ {
return row; return static_cast<std::size_t>(row);
} }
} }
else else
return stride; return static_cast<std::size_t>(stride);
}; };
StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
......
...@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride = auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == 0) if(stride == -1)
{ {
// give a chance if stride is 0, return a default packed stride // give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return static_cast<std::size_t>(col); return static_cast<std::size_t>(col);
......
...@@ -115,21 +115,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -115,21 +115,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}; };
auto f_get_default_stride = auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == 0) if(stride == -1)
{ {
// give a chance if stride is zero, return a default packed stride // give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return col; return static_cast<std::size_t>(col);
} }
else else
{ {
return row; return static_cast<std::size_t>(row);
} }
} }
else else
return stride; return static_cast<std::size_t>(stride);
}; };
StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment