Commit 0b11569f authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

parents e8d3a0fb fa9a0a5c
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <functional>
#include <iostream>
......@@ -7,11 +10,10 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "ck/library/utility/fill.hpp"
#include "profiler/include/profile_convnd_fwd.hpp"
namespace {
enum struct ConvDataType
......@@ -301,7 +303,7 @@ void profile_convnd_instances(ConvDataType data_type,
} // namespace
int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
int profile_convnd_fwd(int argc, char* argv[])
{
using namespace ck::utils::conv;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......@@ -11,10 +14,6 @@ enum struct GemmMatrixLayout
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType
......@@ -27,7 +26,7 @@ enum struct GemmDataType
int profile_gemm(int argc, char* argv[])
{
if(!(argc == 14 || argc == 15))
if(argc != 14)
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
......@@ -38,9 +37,8 @@ int profile_gemm(int argc, char* argv[])
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
}
......@@ -58,350 +56,125 @@ int profile_gemm(int argc, char* argv[])
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
int KBatch = 1;
if(argc == 15)
KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
using INT32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
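// Generic dispatch helper: the tag arguments carry the data types and matrix
// layouts (recovered via decltype below); negative strides fall back to the
// packed defaults computed from M, N, K.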
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass =
ck::profiler::profile_gemm_impl<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 0;
return 1;
}
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......@@ -13,10 +16,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
MK_NK_MN_MN_MN, // 1
KM_KN_MN_MN_MN, // 2
KM_NK_MN_MN_MN, // 3
MK_KN_NM_MN_MN, // 4
MK_NK_NM_MN_MN, // 5
KM_KN_NM_MN_MN, // 6
KM_NK_NM_MN_MN, // 7
};
enum struct MatrixDataType
......@@ -98,17 +97,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
return ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>(
bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>(
do_verification,
init_method,
do_log,
......@@ -121,6 +120,8 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
(StrideE < 0) ? DefaultStrideE : StrideE);
return pass ? 0 : 1;
};
if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
......@@ -146,6 +147,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 0;
return 1;
}
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/include/profile_gemm_splitk_impl.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
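// Example invocation (hypothetical sizes; profiler binary name assumed):
//   ckProfiler gemm_splitk 1 0 1 1 0 1 3840 4096 4096 -1 -1 -1 4
// i.e. fp16, A[m, k] * B[k, n] = C[m, n], verification on, integer init,
// no tensor dump, timed kernels, default (packed) strides, KBatch = 4.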
int profile_gemm_splitk(int argc, char* argv[])
{
if(argc != 15)
{
printf("arg1: tensor operation (gemm_splitk: Split-K GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
using F32 = float;
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
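// Same dispatch pattern as profile_gemm, with an extra KBatch argument: the K
// loop is split into KBatch partial GEMMs whose results are accumulated into C.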
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/include/profile_normalization_impl.hpp"
using ck::index_t;
using ck::profiler::NormDataType;
using ck::profiler::NormType;
struct ArgParser
{
std::unordered_map<std::string, NormType> norm_dict = {{"layernorm", NormType::LAYERNORM},
{"batchnorm", NormType::BATCHNORM},
{"softmax", NormType::SOFTMAX}};
std::unordered_map<std::string, std::vector<int>> long_opts = {
{"length", {}}, {"stride", {}}, {"reduce", {}}, {"alpha", {}}, {"beta", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help()
{
std::cout << "arg1: tensor operation (layernorm/batchnorm/softmax)\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg3: verification (0: no; 1: yes)\n"
<< "arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg5: print tensor value (0: no; 1: yes)\n"
<< "arg6: time kernel (0=n0, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 8 4 256) \n"
<< "--stride: tensor strides (e.g, --stride 1024 256 1)\n"
<< "--reduce: to-reduce dimensions (e.g, --reduce 2)\n"
<< "--alpha: alpha scaling value\n"
<< "--beta: beta scaling value\n"
<< std::endl;
}
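// Example invocation (hypothetical sizes; profiler binary name assumed):
//   ckProfiler layernorm 1 1 1 0 1 --length 8 4 256 --reduce 2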
int profile_normalization(int argc, char* argv[])
{
if(argc <= 2)
{
print_help();
return 0;
}
ArgParser arg_parser;
// short unnamed options
const NormType norm_type = arg_parser.norm_dict[argv[1]];
const NormDataType data_type = static_cast<NormDataType>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
// parse the long options
arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"];
const std::vector<index_t> stride = arg_parser.long_opts["stride"];
const std::vector<index_t> reduce = arg_parser.long_opts["reduce"];
const index_t alpha =
arg_parser.long_opts["alpha"].empty() ? 1 : arg_parser.long_opts["alpha"][0];
const index_t beta = arg_parser.long_opts["beta"].empty() ? 0 : arg_parser.long_opts["beta"][0];
if(data_type == NormDataType::F16_F16)
{
ck::profiler::profile_normalization_impl<ck::half_t, float, ck::half_t>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else if(data_type == NormDataType::F32_F32)
{
ck::profiler::profile_normalization_impl<float, float, float>(do_verification,
init_method,
do_log,
time_kernel,
length,
stride,
reduce,
float(alpha),
float(beta),
norm_type);
}
else
{
throw std::runtime_error("not implemented yet");
}
return 0;
}
// hijack main() for quick debugging
// int main(int argc, char* argv[])
// {
// profile_normalization(argc, argv);
// return 0;
// }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <fstream>
#include <cstdlib>
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <cstring>
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "profiler/include/profile_convnd_fwd.hpp"
#include <cstring>
int profile_gemm(int, char*[]);
int profile_gemm_splitk(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]);
int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_convnd_fwd(int argc, char* argv[]);
int profile_convnd_bwd_data(int, char*[], int);
int profile_reduce(int, char*[]);
int profile_conv_bwd_weight(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int profile_gemm_add_add_fastgelu(int, char*[]);
int profile_normalization(int, char*[]);
int profile_reduce(int, char*[]);
static void print_helper_message()
{
// clang-format off
printf("arg1: tensor operation (gemm: GEMM\n"
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
" gemm_reduce: GEMM+Reduce\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv1d_bwd_data: BackwardConvolution data 1 dim\n"
" conv2d_bwd_data: BackwardConvolution data 2 dim\n"
" conv3d_bwd_data: BackwardConvolution data 3 dim\n"
" reduce: Reduce\n"
" conv2d_bwd_weight: Backward Weight Convolution 2d\n"
" gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n");
printf("arg1: tensor operation (gemm: GEMM\n"
" gemm_splitk: Split-K GEMM\n"
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
" gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"
" gemm_reduce: GEMM+Reduce\n"
" batched_gemm: Batched GEMM\n"
" grouped_gemm: Grouped GEMM\n"
" conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv1d_bwd_data: BackwardConvolution data 1 dim\n"
" conv2d_bwd_data: BackwardConvolution data 2 dim\n"
" conv3d_bwd_data: BackwardConvolution data 3 dim\n"
" conv2d_bwd_weight: Backward Weight Convolution 2d\n"
" reduce: Reduce\n");
// clang-format on
}
......@@ -57,6 +59,10 @@ int main(int argc, char* argv[])
{
return profile_gemm(argc, argv);
}
else if(strcmp(argv[1], "gemm_splitk") == 0)
{
return profile_gemm_splitk(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
return profile_gemm_bias_2d(argc, argv);
......@@ -91,7 +97,7 @@ int main(int argc, char* argv[])
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return ck::profiler::profile_convnd_fwd(argc, argv);
return profile_convnd_fwd(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
......@@ -125,6 +131,11 @@ int main(int argc, char* argv[])
{
return profile_gemm_add_add_fastgelu(argc, argv);
}
else if(strcmp(argv[1], "batchnorm") == 0 || strcmp(argv[1], "layernorm") == 0 ||
strcmp(argv[1], "softmax") == 0)
{
return profile_normalization(argc, argv);
}
else
{
print_helper_message();
......
......@@ -13,6 +13,7 @@ function(add_test_executable TEST_NAME)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> )
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_test_executable TEST_NAME)
include(GoogleTest)
......@@ -26,6 +27,7 @@ function(add_gtest_executable TEST_NAME)
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
gtest_discover_tests(${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_gtest_executable TEST_NAME)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/include/profile_batched_gemm_impl.hpp"
......