Commit 55a01eef authored by aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into navi3x_mD_batchedGEMM_GroupConvFwd
parents f1b53d78 ba40c2ce
...@@ -185,9 +185,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm ...@@ -185,9 +185,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
}, },
Number<NumEmbeddings>{}); Number<NumEmbeddings>{});
auto out_data_refs = generate_tie( auto out_data_refs = generate_tie(
[&](auto output_index_) -> auto& { [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
return acc_thread_buf(Number<register_offset>{});
},
Number<1>{}); Number<1>{});
unpack2(emb_elementwise_op, out_data_refs, in_data_refs); unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
}); });
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <random>
#include "profiler/profile_grouped_gemm_impl.hpp" #include "profiler/profile_grouped_gemm_impl.hpp"
...@@ -18,7 +19,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -18,7 +19,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
bool TestGroupedGemm() bool TestGroupedGemm()
{ {
int group_count = rand() % 10 + 1;
std::mt19937 gen(19391);
std::uniform_int_distribution<> distrib(1, 10);
int group_count = distrib(gen);
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
...@@ -29,9 +33,9 @@ bool TestGroupedGemm() ...@@ -29,9 +33,9 @@ bool TestGroupedGemm()
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
Ms.push_back(256 + 256 * (rand() % 10)); Ms.push_back(256 + 256 * distrib(gen));
Ns.push_back(256 + 256 * (rand() % 10)); Ns.push_back(256 + 256 * distrib(gen));
Ks.push_back(128 + 128 * (rand() % 10)); Ks.push_back(128 + 128 * distrib(gen));
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment