"local_mode/README.md" did not exist on "18493eefc09cd47a6d47da3af0d73cbee063de9f"
Commit 97b32147 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Fix strides

parent 8e89c4d9
...@@ -358,15 +358,15 @@ struct GemmParams ...@@ -358,15 +358,15 @@ struct GemmParams
* *
* A[16x128] * B[128x16] = C[16x16], all row major. * A[16x128] * B[128x16] = C[16x16], all row major.
*/ */
GemmParams() : M(16), N(16), K(128), StrideA(128), StrideB(16), StrideC(16) {} GemmParams() : M(16), N(16), K(128) {}
ck::index_t M; ck::index_t M;
ck::index_t N; ck::index_t N;
ck::index_t K; ck::index_t K;
ck::index_t StrideA; ck::index_t StrideA = -1;
ck::index_t StrideB; ck::index_t StrideB = -1;
ck::index_t StrideC; ck::index_t StrideC = -1;
}; };
template <typename GemmInstance, template <typename GemmInstance,
...@@ -465,9 +465,30 @@ struct TestMFMA ...@@ -465,9 +465,30 @@ struct TestMFMA
params.M = BLOCK_M; params.M = BLOCK_M;
params.N = BLOCK_N; params.N = BLOCK_N;
params.K = BLOCK_K; params.K = BLOCK_K;
params.StrideA = BLOCK_K; // M K
params.StrideB = BLOCK_N; // K N auto f_get_default_stride = [](std::size_t row,
params.StrideC = BLOCK_N; // M N std::size_t col,
ck::index_t stride,
auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
auto host_tensors = PrepareGemmTensors(params); auto host_tensors = PrepareGemmTensors(params);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment