Commit 2d6fe2cd authored by Adam Osewski's avatar Adam Osewski
Browse files

Add argument to choose input layout.

parent fc1edc32
......@@ -54,16 +54,106 @@ struct LayoutConfig
bool CRowMajor;
};
enum class ABDataLayout : int
{
NN,
NT,
TN,
TT,
ALL,
};
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void insertNNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, false, true},
add_gemm_wavelet_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
});
}
void insertNTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, true, true},
add_gemm_wavelet_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
});
}
void insertTNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, false, true},
add_gemm_wavelet_f16_tn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
});
}
void insertTTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, true, true},
add_gemm_wavelet_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
});
}
void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
{
switch(layout)
{
case ABDataLayout::NN: insertNNProblems(v); break;
case ABDataLayout::NT: insertNTProblems(v); break;
case ABDataLayout::TN: insertTNProblems(v); break;
case ABDataLayout::TT: insertTTProblems(v); break;
case ABDataLayout::ALL:
default:
insertNNProblems(v);
insertNTProblems(v);
insertTNProblems(v);
insertTTProblems(v);
};
}
int main(int argc, char* argv[])
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
std::vector<ProblemDesc> problems;
// = {
// clang-format off
// Use following if you run it on MI200 GPU
......@@ -87,22 +177,22 @@ int main(int argc, char* argv[])
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// wavelet
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64},
// 110 tiles
......@@ -123,27 +213,32 @@ int main(int argc, char* argv[])
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
};
// };
bool do_verification = true;
bool time_kernel = true;
auto input_layout = ABDataLayout::ALL;
if(argc == 1)
{
// use default
}
else if(argc == 3)
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
input_layout = ABDataLayout{std::stoi(argv[3])};
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl;
<< "arg2: time kernel (0=no, 1=yes)" << std::endl
<< "arg3: Input data layout (0=NN, 1=NT, 2=TN, 3=TT)" << std::endl;
return 0;
}
get_problems(problems, input_layout);
bool pass = true;
for(auto& p : problems)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment