Commit ce72f286 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 50320413 f30e5975
......@@ -21,6 +21,8 @@ template <typename XDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename SaveMeanInvStdDataType,
bool SaveMeanInvStd,
index_t Rank>
bool profile_layernorm_impl(int do_verification,
int init_method,
......@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification,
Tensor<GammaDataType> gamma(reduce_length);
Tensor<BetaDataType> beta(reduce_length);
Tensor<YDataType> y(length);
Tensor<SaveMeanInvStdDataType> save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> save_inv_std({length[0]});
Tensor<YDataType> host_y(length);
Tensor<SaveMeanInvStdDataType> host_save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({length[0]});
std::vector<index_t> strideXY =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideGammaBeta = strideXY;
strideGammaBeta[0] = 0;
std::vector<index_t> strideSaveMeanInvStd = {1};
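// saved mean / inv-std hold one value per non-reduced (outer) row and are stored contiguously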
switch(init_method)
{
case 0:
......@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification,
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
......@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim>;
......@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification)
{
using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
PassThrough,
Rank,
NumReduceDim>;
using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernorm<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4);
auto ref_invoker = ref.MakeInvoker();
auto ref_argument = ref.MakeArgument(x,
gamma,
beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
length,
reduce_dim,
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
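// build the device-op argument; when SaveMeanInvStd is enabled, also pass the mean / inv-std buffers and strides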
auto f_get_argument = [&](auto& inst_ptr) {
if constexpr(SaveMeanInvStd)
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
PassThrough{});
else
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
};
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
auto argument_ptr = f_get_argument(inst_ptr);
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
......@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification,
beta.mDesc.GetElementSize() * sizeof(BetaDataType) +
y.mDesc.GetElementSize() * sizeof(YDataType);
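// account for the extra bytes written to the saved mean and inverse-std tensors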
if constexpr(SaveMeanInvStd)
num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) +
save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
......@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification)
{
y_dev.FromDevice(y.mData.data());
bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
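// also verify the saved mean and inverse standard deviation against the host reference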
if constexpr(SaveMeanInvStd)
{
save_mean_dev.FromDevice(save_mean.mData.data());
pass &= ck::utils::check_err(
save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3);
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(save_inv_std.mData,
host_save_inv_std.mData,
"Error: Incorrect results",
1e-3,
1e-3);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl;
......
......@@ -25,8 +25,6 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
)
......@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif()
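# contraction profilers are only built when fp32/fp64 data types are enabled (or DTYPES is unset)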
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler)
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
......@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
......@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
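# contraction instance libraries are likewise linked only for fp32/fp64 builds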
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
......
......@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
};
#define OP_NAME "gemm_splitk"
......@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if(argc != 15)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
auto c_layout,
auto compute_type) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
......@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
using ComputeType = decltype(compute_type);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
......@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType,
ALayout,
BLayout,
CLayout>(
CLayout,
ComputeType>(
do_verification,
init_method,
do_log,
......@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
#if defined CK_ENABLE_FP8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
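// data type 6 (F16_F16_F16_F8): fp16 A/B/C tensors with F8 passed as the compute type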
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
}
#endif
else
......
......@@ -20,10 +20,11 @@ enum struct ConvLayout
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8 // 3
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8, // 3
I8_I8_I8 // 4
};
#define OP_NAME "grouped_conv_bwd_weight"
......@@ -35,7 +36,8 @@ static void print_helper_msg()
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight fp32, Output bf16\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
<< " 4: Input int8, Weight int8, Output int8)\n"
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]\n"
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
......@@ -84,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using namespace ck::tensor_layout::convolution;
......@@ -139,83 +137,93 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_F32_BF16)
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
}
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
......
......@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32>(
ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32, F32, false>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16>(
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
do_verification, init_method, do_log, time_kernel, length);
}
else
......
......@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>(
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>(
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else
......
......@@ -11,40 +11,36 @@ function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
foreach(source IN LISTS ARGN)
set(test 0)
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
set(test 1)
endif()
if(test EQUAL 1)
message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
if(test EQUAL 1)
message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
......@@ -70,38 +66,34 @@ function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
foreach(source IN LISTS ARGN)
set(test 0)
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 1)
endif()
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
set(test 1)
endif()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
......
......@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
endif()
add_gtest_executable(test_batched_gemm test_batched_gemm.cpp)
target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
set(target 1)
endif()
endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 512;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
struct GemmParams
{
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t BatchCount;
};
class TestBatchedGemm : public ::testing::Test
{
protected:
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
std::vector<GemmParams> params;
template <typename DataType>
void Run()
{
using namespace ck::tensor_operation::device;
bool pass = true;
for(auto& param : params)
{
const auto M = param.M;
const auto N = param.N;
const auto K = param.K;
const auto BatchCount = param.BatchCount;
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
}
EXPECT_TRUE(pass);
}
};
#ifdef CK_ENABLE_INT8
TEST_F(TestBatchedGemm, i8)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<int8_t>();
}
#endif
#ifdef CK_ENABLE_BF16
TEST_F(TestBatchedGemm, bf16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::bhalf_t>();
}
#endif
#ifdef CK_ENABLE_FP16
TEST_F(TestBatchedGemm, fp16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::half_t>();
}
#endif
#ifdef CK_ENABLE_FP32
TEST_F(TestBatchedGemm, fp32)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<float>();
}
#endif
add_gtest_executable(test_contraction test_contraction.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
set(target 1)
endif()
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES)
add_gtest_executable(test_contraction test_contraction.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
set(target 1)
endif()
endif()
endforeach()
add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp)
target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance)
add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp)
target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility)
if (USE_BITINT_EXTENSION_INT4)
add_gtest_executable(test_int4 int4.cpp)
add_gtest_executable(test_int4 test_int4.cpp)
if(result EQUAL 0)
target_link_libraries(test_int4 PRIVATE utility)
endif()
endif()
add_gtest_executable(test_fp8 fp8.cpp)
add_gtest_executable(test_fp8 test_fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8 PRIVATE utility)
endif()
add_gtest_executable(test_bf8 bf8.cpp)
add_gtest_executable(test_bf8 test_bf8.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility)
endif()
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1)
endif()
endforeach()
\ No newline at end of file