Commit f30e5975 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 91b414cd bec84efb
...@@ -21,6 +21,8 @@ template <typename XDataType, ...@@ -21,6 +21,8 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
bool SaveMeanInvStd,
index_t Rank> index_t Rank>
bool profile_layernorm_impl(int do_verification, bool profile_layernorm_impl(int do_verification,
int init_method, int init_method,
...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification, ...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification,
Tensor<GammaDataType> gamma(reduce_length); Tensor<GammaDataType> gamma(reduce_length);
Tensor<BetaDataType> beta(reduce_length); Tensor<BetaDataType> beta(reduce_length);
Tensor<YDataType> y(length); Tensor<YDataType> y(length);
Tensor<SaveMeanInvStdDataType> save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> save_inv_std({length[0]});
Tensor<YDataType> host_y(length); Tensor<YDataType> host_y(length);
Tensor<SaveMeanInvStdDataType> host_save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({length[0]});
std::vector<index_t> strideXY = std::vector<index_t> strideXY =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideGammaBeta = strideXY; std::vector<index_t> strideGammaBeta = strideXY;
strideGammaBeta[0] = 0; strideGammaBeta[0] = 0;
std::vector<index_t> strideSaveMeanInvStd = {1};
switch(init_method) switch(init_method)
{ {
case 0: case 0:
...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification, ...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification,
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification, ...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification, ...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType, using ReferenceInstance =
GammaDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
BetaDataType, GammaDataType,
YDataType, BetaDataType,
ComputeDataType, YDataType,
PassThrough, SaveMeanInvStdDataType,
Rank, ComputeDataType,
NumReduceDim>; PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
length,
reduce_dim,
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; int num_kernel = 0;
auto f_get_argument = [&](auto& inst_ptr) {
if constexpr(SaveMeanInvStd)
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
PassThrough{});
else
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
};
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
auto argument_ptr = inst_ptr->MakeArgumentPointer(length, auto argument_ptr = f_get_argument(inst_ptr);
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(inst_ptr->IsSupportedArgument(argument_ptr.get())) if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification, ...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification,
beta.mDesc.GetElementSize() * sizeof(BetaDataType) + beta.mDesc.GetElementSize() * sizeof(BetaDataType) +
y.mDesc.GetElementSize() * sizeof(YDataType); y.mDesc.GetElementSize() * sizeof(YDataType);
if constexpr(SaveMeanInvStd)
num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) +
save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel) if(time_kernel)
...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification, ...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if constexpr(SaveMeanInvStd)
{
save_mean_dev.FromDevice(save_mean.mData.data());
pass &= ck::utils::check_err(
save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3);
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(save_inv_std.mData,
host_save_inv_std.mData,
"Error: Incorrect results",
1e-3,
1e-3);
}
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl;
......
...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES ...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp profile_batchnorm_infer.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_grouped_conv_bwd_data.cpp profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp profile_conv_tensor_rearrange.cpp
) )
...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif() endif()
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler) set(PROFILER_EXECUTABLE ckProfiler)
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan ...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d ...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
if(DL_KERNELS) if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
......
...@@ -25,6 +25,7 @@ enum struct GemmDataType ...@@ -25,6 +25,7 @@ enum struct GemmDataType
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
F8_F16_F16, // 4 F8_F16_F16, // 4
F16_F8_F16, // 5 F16_F8_F16, // 5
F16_F16_F16_F8, // 6
}; };
#define OP_NAME "gemm_splitk" #define OP_NAME "gemm_splitk"
...@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -35,7 +36,8 @@ int profile_gemm_splitk(int argc, char* argv[])
if(argc != 15) if(argc != 15)
{ {
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
...@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -80,7 +82,8 @@ int profile_gemm_splitk(int argc, char* argv[])
auto c_type, auto c_type,
auto a_layout, auto a_layout,
auto b_layout, auto b_layout,
auto c_layout) { auto c_layout,
auto compute_type) {
using ADataType = decltype(a_type); using ADataType = decltype(a_type);
using BDataType = decltype(b_type); using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type); using AccDataType = decltype(acc_type);
...@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -90,6 +93,8 @@ int profile_gemm_splitk(int argc, char* argv[])
using BLayout = decltype(b_layout); using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout); using CLayout = decltype(c_layout);
using ComputeType = decltype(compute_type);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K; const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M; const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
...@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -100,7 +105,8 @@ int profile_gemm_splitk(int argc, char* argv[])
CDataType, CDataType,
ALayout, ALayout,
BLayout, BLayout,
CLayout>( CLayout,
ComputeType>(
do_verification, do_verification,
init_method, init_method,
do_log, do_log,
...@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -118,68 +124,84 @@ int profile_gemm_splitk(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
} }
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}); return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
}
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
} }
#endif #endif
else else
......
...@@ -20,10 +20,11 @@ enum struct ConvLayout ...@@ -20,10 +20,11 @@ enum struct ConvLayout
enum struct ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_F32_BF16, // 2 BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8 // 3 F16_F16_F16_BF8_F8, // 3
I8_I8_I8 // 4
}; };
#define OP_NAME "grouped_conv_bwd_weight" #define OP_NAME "grouped_conv_bwd_weight"
...@@ -35,7 +36,8 @@ static void print_helper_msg() ...@@ -35,7 +36,8 @@ static void print_helper_msg()
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight fp32, Output bf16\n" << " 2: Input bf16, Weight fp32, Output bf16\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n" << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
<< " 4: Input int8, Weight int8, Output int8)\n"
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]\n" "N, K, Ho, Wo]\n"
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
...@@ -84,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -84,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#ifdef CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using namespace ck::tensor_layout::convolution; using namespace ck::tensor_layout::convolution;
...@@ -139,83 +137,93 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -139,83 +137,93 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
} }
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[]) ...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Float) if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32>( ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Half) else if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16>( ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......
...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half) if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>( ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>( ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......
...@@ -11,40 +11,36 @@ function(add_test_executable TEST_NAME) ...@@ -11,40 +11,36 @@ function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}") message("adding test ${TEST_NAME}")
set(result 1) set(result 1)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
if(type MATCHES "fp16") set(test 1)
set(type1 "_f16") endif()
elseif(type MATCHES "fp32") if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(type1 "_f32") set(test 1)
elseif(type MATCHES "fp8") endif()
set(type1 "_f8") if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
elseif(type MATCHES "bf16") set(test 1)
set(type1 "_b16") endif()
elseif(type MATCHES "fp64") if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(type1 "_f64") set(test 1)
elseif(type MATCHES "int8") endif()
set(type1 "_i8") if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
endif() set(test 1)
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") endif()
#if filename matches any selected type, exit type loop and do no exclude the file from the list if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 0) set(test 1)
break() endif()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND set(test 1)
NOT(source MATCHES type OR source MATCHES type1)) endif()
#if filename contains a type which doesn't match any selected type, mark it for removal if(test EQUAL 1)
set(test 1) message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
if(test EQUAL 1) endif()
message("removing test ${source} ") foreach(source IN LISTS ARGN)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ") message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
...@@ -70,38 +66,34 @@ function(add_gtest_executable TEST_NAME) ...@@ -70,38 +66,34 @@ function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}") message("adding gtest ${TEST_NAME}")
set(result 1) set(result 1)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
if(type MATCHES "fp16") set(test 1)
set(type1 "_f16") endif()
elseif(type MATCHES "fp32") if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
set(type1 "_f32") set(test 1)
elseif(type MATCHES "fp8") endif()
set(type1 "_f8") if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
elseif(type MATCHES "bf16") set(test 1)
set(type1 "_b16") endif()
elseif(type MATCHES "fp64") if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
set(type1 "_f64") set(test 1)
elseif(type MATCHES "int8") endif()
set(type1 "_i8") if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
endif() set(test 1)
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") endif()
#if filename matches any selected type, exit type loop and do no exclude the file from the list if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
set(test 0) set(test 1)
break() endif()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND set(test 1)
NOT(source MATCHES type OR source MATCHES type1)) endif()
#if filename contains a type which doesn't match any selected type, mark it for removal if(test EQUAL 1)
set(test 1) message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif() endif()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
......
...@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) ...@@ -2,22 +2,8 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) add_gtest_executable(test_batched_gemm test_batched_gemm.cpp)
if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
endif()
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
endif()
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 512;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
namespace {
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
int main()
{
int M = 256;
int N = 256;
int K = 128;
int BatchCount = 3;
bool pass = true;
using namespace ck::tensor_operation::device;
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
BDataType,
CDataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
struct GemmParams
{
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t BatchCount;
};
class TestBatchedGemm : public ::testing::Test
{
protected:
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
std::vector<GemmParams> params;
template <typename DataType>
void Run()
{
using namespace ck::tensor_operation::device;
bool pass = true;
for(auto& param : params)
{
const auto M = param.M;
const auto N = param.N;
const auto K = param.K;
const auto BatchCount = param.BatchCount;
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Row,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Row,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Row,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass =
pass && ck::profiler::profile_batched_gemm_impl<DataType,
DataType,
DataType,
Col,
Col,
Row,
PassThrough,
PassThrough,
PassThrough,
DeviceBatchedGemm<Col,
Col,
Row,
DataType,
DataType,
DataType,
PassThrough,
PassThrough,
PassThrough>>(
true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
}
EXPECT_TRUE(pass);
}
};
#ifdef CK_ENABLE_INT8
TEST_F(TestBatchedGemm, i8)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<int8_t>();
}
#endif
#ifdef CK_ENABLE_BF16
TEST_F(TestBatchedGemm, bf16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::bhalf_t>();
}
#endif
#ifdef CK_ENABLE_FP16
TEST_F(TestBatchedGemm, fp16)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<ck::half_t>();
}
#endif
#ifdef CK_ENABLE_FP32
TEST_F(TestBatchedGemm, fp32)
{
this->params.push_back({64, 64, 64, 2});
this->params.push_back({64, 64, 64, 1});
this->params.push_back({60, 60, 60, 2});
this->params.push_back({68, 68, 68, 2});
this->params.push_back({40, 40, 40, 2});
this->params.push_back({256, 256, 128, 3});
this->template Run<float>();
}
#endif
add_gtest_executable(test_contraction test_contraction.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) add_gtest_executable(test_contraction test_contraction.cpp)
set(target 1) target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
endif() add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
set(target 1)
endif()
endif()
endforeach() endforeach()
add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp) add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp)
target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance) target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance)
add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp) add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp)
target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility) target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility)
if (USE_BITINT_EXTENSION_INT4) if (USE_BITINT_EXTENSION_INT4)
add_gtest_executable(test_int4 int4.cpp) add_gtest_executable(test_int4 test_int4.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_int4 PRIVATE utility) target_link_libraries(test_int4 PRIVATE utility)
endif() endif()
endif() endif()
add_gtest_executable(test_fp8 fp8.cpp) add_gtest_executable(test_fp8 test_fp8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_fp8 PRIVATE utility) target_link_libraries(test_fp8 PRIVATE utility)
endif() endif()
add_gtest_executable(test_bf8 bf8.cpp) add_gtest_executable(test_bf8 test_bf8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility) target_link_libraries(test_bf8 PRIVATE utility)
endif() endif()
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment