Commit 2d6fe2cd authored by Adam Osewski's avatar Adam Osewski
Browse files

Add argument to choose input layout.

parent fc1edc32
...@@ -54,20 +54,110 @@ struct LayoutConfig ...@@ -54,20 +54,110 @@ struct LayoutConfig
bool CRowMajor; bool CRowMajor;
}; };
enum class ABDataLayout : int
{
NN,
NT,
TN,
TT,
ALL,
};
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
using ProblemDesc = std::tuple<GemmParams, LayoutConfig, OpFactoryFn>;
void insertNNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, false, true},
add_gemm_wavelet_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true},
// add_gemm_wavelet_f16_nn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}
});
}
void insertNTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{false, true, true},
add_gemm_wavelet_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true},
// add_gemm_wavelet_f16_nt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}
});
}
void insertTNProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, false, true},
add_gemm_wavelet_f16_tn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true},
// add_gemm_wavelet_f16_tn_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}
});
}
void insertTTProblems(std::vector<ProblemDesc>& v)
{
v.insert(std::begin(v),
{
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096},
LayoutConfig{true, true, true},
add_gemm_wavelet_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true},
// add_gemm_wavelet_f16_tt_128x128}, {GemmParams{1024, 832, 4096},
// LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}
});
}
void get_problems(std::vector<ProblemDesc>& v, ABDataLayout layout)
{
switch(layout)
{
case ABDataLayout::NN: insertNNProblems(v); break;
case ABDataLayout::NT: insertNTProblems(v); break;
case ABDataLayout::TN: insertTNProblems(v); break;
case ABDataLayout::TT: insertTTProblems(v); break;
case ABDataLayout::ALL:
default:
insertNNProblems(v);
insertNTProblems(v);
insertTNProblems(v);
insertTTProblems(v);
};
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
// them in a single vector. Instead we use abstract BaseOperator class and dynamic_cast() it std::vector<ProblemDesc> problems;
// upon invocation.
// And since DeviceGemm does not expose template arg information, an extra book keeping class // = {
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to. // clang-format off
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
// Use following if you run it on MI200 GPU // Use following if you run it on MI200 GPU
// 104 tiles // 104 tiles
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256}, // {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128}, // {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
...@@ -87,22 +177,22 @@ int main(int argc, char* argv[]) ...@@ -87,22 +177,22 @@ int main(int argc, char* argv[])
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64}, // {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// wavelet // wavelet
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256}, // {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128}, // {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128}, // {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64}, // {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_wavelet_f16_nn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256}, // {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128}, // {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128}, // {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64}, // {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_wavelet_f16_nt_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256}, // {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128}, // {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128}, // {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64}, // {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256}, // {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128}, // {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128}, // {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64}, // {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_wavelet_f16_tt_128x64},
// 110 tiles // 110 tiles
...@@ -122,28 +212,33 @@ int main(int argc, char* argv[]) ...@@ -122,28 +212,33 @@ int main(int argc, char* argv[])
// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128}, // {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128}, // {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64}, // {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on // clang-format on
}; // };
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
auto input_layout = ABDataLayout::ALL;
if(argc == 1) if(argc == 1)
{ {
// use default // use default
} }
else if(argc == 3) else if(argc == 4)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]); time_kernel = std::stoi(argv[2]);
input_layout = ABDataLayout{std::stoi(argv[3])};
} }
else else
{ {
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl; << "arg2: time kernel (0=no, 1=yes)" << std::endl
<< "arg3: Input data layout (0=NN, 1=NT, 2=TN, 3=TT)" << std::endl;
return 0; return 0;
} }
get_problems(problems, input_layout);
bool pass = true; bool pass = true;
for(auto& p : problems) for(auto& p : problems)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment