Commit 47e0607c authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix clang-format

parent e48e7f38
......@@ -4,11 +4,10 @@
#include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
// #include "gemm_f16_nn_instance.hpp"
// #include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
// #include "gemm_f16_tn_instance.hpp"
// #include "gemm_f16_tt_instance.hpp"
#include "gemm_wavelet_f16_nn_instance.hpp"
#include "gemm_wavelet_f16_nt_instance.hpp"
......@@ -73,61 +72,53 @@ using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void insertNNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
v.insert(std::end(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, false, true},
add_gemm_wavelet_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
// clang-format off
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
// clang-format on
});
}
void insertNTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
v.insert(std::end(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, true, true},
add_gemm_wavelet_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
// clang-format off
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
// clang-format on
});
}
void insertTNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
v.insert(std::end(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, false, true},
add_gemm_wavelet_f16_tn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
// clang-format off
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
// clang-format on
});
}
void insertTTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
v.insert(std::end(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, true, true},
add_gemm_wavelet_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
// clang-format off
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
// clang-format on
});
}
......@@ -151,9 +142,7 @@ void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
int main(int argc, char* argv[])
{
std::vector<ProblemDesc> problems;
// = {
// std::vector<ProblemDesc> problems = {
// clang-format off
// Use following if you run it on MI200 GPU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment