Commit 9032352f authored by Jakub Piasecki

resolved conflicts

parents d5c5d2a3 64d5c4d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DsDataType = ck::Tuple<>;
using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
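// Reading the first instance above against the column header: NumGemmKPrefetchStage = 1,
// BlockSize = 256, MPerBlock = 128, NPerBlock = 256, KPerBlock = 64, AK1 = BK1 = 8,
// MPerXDL = NPerXDL = 32, with a 2 x 4 grid of XDL waves per block; the remaining rows
// vary the tile and thread-cluster parameters.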
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
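For orientation, a minimal consumer of the factory function above might look like the
following sketch (not part of this commit; it assumes the aliases declared in this file
are in scope and that each instance describes itself via the usual
BaseOperator::GetTypeString()):

// Hypothetical consumer sketch: enumerate the registered bf16 fixed-NK instances.
#include <iostream>
#include <memory>
#include <vector>

void list_bf16_fixed_nk_instances()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // Same element-type / layout configuration as the factory signature above.
    std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, Col, DsLayout, Row,
                                                         BF16, BF16, DsDataType, BF16,
                                                         PassThrough, PassThrough, PassThrough>>>
        ops;

    add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(ops);

    for(const auto& op : ops) // one entry per tile configuration in the tuple above
        std::cout << op->GetTypeString() << '\n';
}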
@@ -342,7 +342,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-LogRangeAsType<int8_t>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
+LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
......
@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
INT8_INT8_BF16, // 8
+F8_F8_F16, // 9
};
#define OP_NAME "gemm_multiply_multiply"
@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->bf16)\n");
"comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
@@ -89,6 +90,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
+using F8 = ck::f8_t;
using I8 = int8_t;
using I32 = int;
@@ -165,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
+else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+{
+return profile(
+F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
+}
else if(data_type == GemmDataType::INT8_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
......
@@ -17,11 +17,11 @@ enum struct GemmMatrixLayout
enum struct GemmDataType
{
-BF16_I8_BF16, // 0
-F16_F16_F16, // 1
-F16_F8_F16, // 2
-F16_I8_F16, // 3
+BF16_I8_BF16, // 0
+F16_F16_F16, // 1
+F16_F8_F16, // 2
+F16_I8_F16, // 3
+BF16_BF16_BF16 // 4
};
#define OP_NAME "grouped_gemm_fixed_nk"
@@ -39,7 +39,6 @@ std::vector<int> argToIntArray(char* input)
{
out.push_back(std::stoi(item));
}
return out;
}
@@ -83,14 +82,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
-using F32 = float;
-using F16 = ck::half_t;
-#if defined(CK_ENABLE_FP8)
-using F8 = ck::f8_t;
-#endif
-using BF16 = ck::bhalf_t;
-using I8 = int8_t;
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
@@ -99,13 +90,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter = std::stoi(argv[16]);
}
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
-if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
-I8,
-BF16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+ck::half_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -123,12 +113,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
-I8,
-BF16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+ck::half_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -146,14 +136,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
#endif
-#if defined(CK_ENABLE_FP16)
-else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#if defined(CK_ENABLE_FP8)
+else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-F16,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+ck::f8_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -171,12 +160,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-F16,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+ck::f8_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -194,14 +183,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
-else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_FP8
+#if defined(CK_ENABLE_INT8)
+else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-F8,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+int8_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -219,12 +208,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-F8,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
+int8_t,
+ck::half_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -242,14 +231,14 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-#endif
-#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
-else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+#endif // CK_ENABLE_INT8
+#if defined(CK_ENABLE_BF16)
+else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-I8,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+ck::bhalf_t,
+ck::bhalf_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
@@ -267,12 +256,59 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
-else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+ck::bhalf_t,
+ck::bhalf_t,
+float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
+do_verification,
+init_method,
+do_log,
+time_kernel,
+Ms,
+Ns,
+Ks,
+StrideAs,
+StrideBs,
+StrideCs,
+kbatch,
+n_warmup,
+n_iter);
+}
+#if defined(CK_ENABLE_INT8)
+else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+{
-ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
-I8,
-F16,
-F32,
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+int8_t,
+ck::bhalf_t,
+float,
+ck::tensor_layout::gemm::RowMajor,
+ck::tensor_layout::gemm::RowMajor,
+ck::tensor_layout::gemm::RowMajor>(
+do_verification,
+init_method,
+do_log,
+time_kernel,
+Ms,
+Ns,
+Ks,
+StrideAs,
+StrideBs,
+StrideCs,
+kbatch,
+n_warmup,
+n_iter);
+}
+else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
+{
+ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
+int8_t,
+ck::bhalf_t,
+float,
+ck::tensor_layout::gemm::RowMajor,
+ck::tensor_layout::gemm::ColumnMajor,
+ck::tensor_layout::gemm::RowMajor>(
@@ -286,11 +322,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
-1,
+kbatch,
n_warmup,
n_iter);
}
-#endif
+#endif // CK_ENABLE_INT8
+#endif // CK_ENABLE_BF16
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
......
@@ -21,16 +21,19 @@ dependencies = []
"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library", "ck4inductor.universal_gemm", "ck4inductor.batched_universal_gemm", "ck4inductor.grouped_conv_fwd"]
[tool.setuptools.package-dir]
ck4inductor = "python/ck4inductor"
"ck4inductor.universal_gemm" = "python/ck4inductor/universal_gemm"
"ck4inductor.batched_universal_gemm" = "python/ck4inductor/batched_universal_gemm"
"ck4inductor.grouped_conv_fwd" = "python/ck4inductor/grouped_conv_fwd"
"ck4inductor.include" = "include"
"ck4inductor.library" = "library"
[tool.setuptools.package-data]
"ck4inductor.include" = ["ck/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp", "src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp", "include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"]
[tool.setuptools.dynamic]
version = { attr = "setuptools_scm.get_version" }
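The explicit enumeration above is needed because a hand-written packages list plus a
package-dir mapping disables setuptools auto-discovery: each new ck4inductor subpackage
must appear in both tables or it is silently left out of the distribution.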
@@ -68,12 +68,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
template_args.insert(2, tuple()) # ds layout
template_args.insert(6, tuple()) # ds dtype
-new_instance = CKGemmOperation(
-*template_args, # type: ignore[arg-type]
-)
-op_instances.append(new_instance)
+try:
+new_instance = CKGemmOperation(
+*template_args, # type: ignore[arg-type]
+)
+op_instances.append(new_instance)
+except TypeError as e:
+log.debug(f"{e} when parsing {line}")
return op_instances
......
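The try/except makes instance parsing forward-compatible: a template whose argument
count no longer matches the CKGemmOperation dataclass is now skipped with a debug log
instead of aborting the whole scan.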
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
import logging
import unittest
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_library as gen_gemm_ops_library,
)
from ck4inductor.universal_gemm.gen_instances import (
gen_ops_preselected as gen_gemm_ops_preselected,
)
from ck4inductor.grouped_conv_fwd.gen_instances import (
gen_conv_ops_library as gen_conv_ops_library,
)
from ck4inductor.batched_universal_gemm.gen_instances import (
gen_ops_library as gen_batched_gemm_ops_library,
)
log = logging.getLogger(__name__)
class TestGenInstances(unittest.TestCase):
def test_gen_gemm_instances(self):
instances = gen_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_preselected_gemm_instances(self):
instances = gen_gemm_ops_preselected()
log.debug("%d preselected gemm instances" % len(instances))
self.assertTrue(instances)
def test_gen_conv_instances(self):
instances = gen_conv_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
def test_gen_batched_gemm_instances(self):
instances = gen_batched_gemm_ops_library()
log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)
@@ -7,6 +7,34 @@ include(gtest)
add_custom_target(tests)
+# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
+# all other tests are labelled as SMOKE_TEST
+set(REGRESSION_TESTS
+test_gemm_standalone_xdl_fp16
+test_gemm_fp16
+test_gemm_splitk
+test_batched_gemm
+test_gemm_universal
+test_batched_gemm_softmax_gemm_fp16
+test_batched_gemm_softmax_gemm_permute_fp16
+test_batched_gemm_bias_softmax_gemm_permute_fp16
+test_batched_gemm_softmax_gemm_permute_bf16
+test_batched_gemm_bias_softmax_gemm_permute_bf16
+test_grouped_gemm_splitk
+test_reduce_no_index
+test_reduce_with_index
+test_convnd_fwd
+test_convnd_bwd_data
+test_grouped_convnd_fwd
+test_grouped_convnd_bwd_weight
+test_softmax_rank3
+test_softmax_rank4
+test_batchnorm_fwd_rank_4
+test_batchnorm_bwd_rank_4
+test_grouped_convnd_bwd_data_xdl
+test_conv_tensor_rearrange
+)
function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}")
set(result 1)
@@ -88,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif()
#message("add_test returns ${result}")
set(result ${result} PARENT_SCOPE)
+if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+message("adding to SMOKE TEST FILTER ${TEST_NAME}")
+set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+add_dependencies(smoke ${TEST_NAME})
+elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+add_dependencies(regression ${TEST_NAME})
+endif()
endfunction()
function(add_gtest_executable TEST_NAME)
@@ -168,6 +205,15 @@ function(add_gtest_executable TEST_NAME)
endif()
#message("add_gtest returns ${result}")
set(result ${result} PARENT_SCOPE)
+if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+#message("adding to smoke test FILTER ${TEST_NAME}")
+set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST")
+add_dependencies(smoke ${TEST_NAME})
+elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS)
+#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
+set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST")
+add_dependencies(regression ${TEST_NAME})
+endif()
endfunction()
add_compile_options(-Wno-c++20-extensions)
......
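With these labels registered, a standard CTest invocation such as ctest -L SMOKE_TEST or
ctest -L REGRESSION_TEST runs one bucket, and the smoke and regression targets fed by the
add_dependencies calls (assumed to be defined elsewhere in the tree) give the equivalent
make entry points, e.g. make regression as the comment above notes.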
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <sstream>
@@ -61,7 +61,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
+using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
@@ -73,8 +73,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
kOutputRank,
1,
0,
-TilePartitioner::kM,
-TilePartitioner::kN>>,
+TilePartitioner::MPerBlock,
+TilePartitioner::NPerBlock>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
......
@@ -59,7 +59,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
+using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
......
@@ -49,3 +49,4 @@ if(result EQUAL 0)
endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
+add_gtest_executable(test_bhalf test_bhalf.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf_t;
using ck::type_convert;
TEST(BHALF_T, Nan)
{
const uint16_t binary_bhalf_nan = 0x7FC0;
const bhalf_t bhalf_nan = ck::bit_cast<bhalf_t>(binary_bhalf_nan);
EXPECT_EQ(bhalf_nan, type_convert<bhalf_t>(ck::NumericLimits<float>::QuietNaN()));
}
TEST(BHALF_T, Inf)
{
const uint16_t binary_bhalf_inf = 0x7F80;
const bhalf_t bhalf_inf = ck::bit_cast<bhalf_t>(binary_bhalf_inf);
EXPECT_EQ(bhalf_inf, type_convert<bhalf_t>(ck::NumericLimits<float>::Infinity()));
}
TEST(BHALF_T, MantissaOverflow)
{
const float abs_tol = std::pow(2, -7);
const uint32_t val = 0x81FFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_NEAR(float_val, type_convert<float>(type_convert<bhalf_t>(float_val)), abs_tol);
}
TEST(BHALF_T, ExpOverflow)
{
const uint32_t val = 0xFF800000;
const float float_val = ck::bit_cast<float>(val);
ASSERT_EQ(type_convert<float>(type_convert<bhalf_t>(float_val)), float_val);
}
TEST(BHALF_T, MantissaExpOverflow)
{
const uint32_t val = 0xFFFFFFFF;
const float float_val = ck::bit_cast<float>(val);
ASSERT_TRUE(std::isnan(float_val));
ASSERT_TRUE(std::isnan(type_convert<float>(type_convert<bhalf_t>(float_val))));
}
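For reference, a standalone sketch of the float-to-bf16 rounding these tests exercise.
Round-to-nearest-even with a canonical quiet NaN is an assumption about how
ck::type_convert behaves (it is consistent with the expectations above), not a copy of
the CK implementation:

// Illustrative only: plain C++, no CK dependency.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bf16_rne(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    if(std::isnan(f))
        return 0x7FC0; // canonical quiet NaN, as the Nan test expects
    // Round to nearest even on the 16 low bits being dropped.
    const uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

int main()
{
    // MantissaOverflow case 0x81FFFFFF: the all-ones mantissa rounds up and the
    // carry propagates into the exponent field, 0x81FF -> 0x8200.
    uint32_t val = 0x81FFFFFF;
    float f;
    std::memcpy(&f, &val, sizeof(f));
    std::printf("0x%08X -> bf16 0x%04X\n", val, float_to_bf16_rne(f));

    // ExpOverflow case 0xFF800000 (-inf) maps to bf16 -inf, 0xFF80.
    val = 0xFF800000;
    std::memcpy(&f, &val, sizeof(f));
    std::printf("0x%08X -> bf16 0x%04X\n", val, float_to_bf16_rne(f));
    return 0;
}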