Commit 47e0607c authored by Adam Osewski's avatar Adam Osewski
Browse files

Fix clang-format

parent e48e7f38
...@@ -4,11 +4,10 @@ ...@@ -4,11 +4,10 @@
#include "gemm_util.hpp" #include "gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
// #include "gemm_f16_nn_instance.hpp" // #include "gemm_f16_nn_instance.hpp"
// #include "gemm_f16_nt_instance.hpp" // #include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp" // #include "gemm_f16_tn_instance.hpp"
// #include "gemm_f16_tt_instance.hpp" // #include "gemm_f16_tt_instance.hpp"
#include "gemm_wavelet_f16_nn_instance.hpp" #include "gemm_wavelet_f16_nn_instance.hpp"
#include "gemm_wavelet_f16_nt_instance.hpp" #include "gemm_wavelet_f16_nt_instance.hpp"
...@@ -73,61 +72,53 @@ using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>; ...@@ -73,61 +72,53 @@ using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void insertNNProblems(std::vector<ProblemDesc>& v) void insertNNProblems(std::vector<ProblemDesc>& v)
{ {
v.insert(std::begin(v), v.insert(std::end(v),
{ {
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, // clang-format off
// add_gemm_wavelet_f16_nn_256x256}, {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
LayoutConfig{false, false, true}, {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
add_gemm_wavelet_f16_nn_256x128}, {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, // clang-format on
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
}); });
} }
void insertNTProblems(std::vector<ProblemDesc>& v) void insertNTProblems(std::vector<ProblemDesc>& v)
{ {
v.insert(std::begin(v), v.insert(std::end(v),
{ {
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, // clang-format off
// add_gemm_wavelet_f16_nt_256x256}, {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
LayoutConfig{false, true, true}, {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
add_gemm_wavelet_f16_nt_256x128}, {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, // clang-format on
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
}); });
} }
void insertTNProblems(std::vector<ProblemDesc>& v) void insertTNProblems(std::vector<ProblemDesc>& v)
{ {
v.insert(std::begin(v), v.insert(std::end(v),
{ {
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, // clang-format off
// add_gemm_wavelet_f16_tn_256x256}, {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
LayoutConfig{true, false, true}, {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
add_gemm_wavelet_f16_tn_256x128}, {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, // clang-format on
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
}); });
} }
void insertTTProblems(std::vector<ProblemDesc>& v) void insertTTProblems(std::vector<ProblemDesc>& v)
{ {
v.insert(std::begin(v), v.insert(std::end(v),
{ {
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, // clang-format off
// add_gemm_wavelet_f16_tt_256x256}, {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
LayoutConfig{true, true, true}, {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
add_gemm_wavelet_f16_tt_256x128}, {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, // clang-format on
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
}); });
} }
...@@ -151,9 +142,7 @@ void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout) ...@@ -151,9 +142,7 @@ void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
std::vector<ProblemDesc> problems; // std::vector<ProblemDesc> problems = {
// = {
// clang-format off // clang-format off
// Use following if you run it on MI200 GPU // Use following if you run it on MI200 GPU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment