Commit 1b5af83d authored by illsilin's avatar illsilin
Browse files

Merge branch 'develop' into lwpck-976

parents aac26d32 1fd27d52
...@@ -21,6 +21,8 @@ template <typename XDataType, ...@@ -21,6 +21,8 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
bool SaveMeanInvStd,
index_t Rank> index_t Rank>
bool profile_layernorm_impl(int do_verification, bool profile_layernorm_impl(int do_verification,
int init_method, int init_method,
...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification, ...@@ -43,13 +45,19 @@ bool profile_layernorm_impl(int do_verification,
Tensor<GammaDataType> gamma(reduce_length); Tensor<GammaDataType> gamma(reduce_length);
Tensor<BetaDataType> beta(reduce_length); Tensor<BetaDataType> beta(reduce_length);
Tensor<YDataType> y(length); Tensor<YDataType> y(length);
Tensor<SaveMeanInvStdDataType> save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> save_inv_std({length[0]});
Tensor<YDataType> host_y(length); Tensor<YDataType> host_y(length);
Tensor<SaveMeanInvStdDataType> host_save_mean({length[0]});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({length[0]});
std::vector<index_t> strideXY = std::vector<index_t> strideXY =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideGammaBeta = strideXY; std::vector<index_t> strideGammaBeta = strideXY;
strideGammaBeta[0] = 0; strideGammaBeta[0] = 0;
std::vector<index_t> strideSaveMeanInvStd = {1};
switch(init_method) switch(init_method)
{ {
case 0: case 0:
...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification, ...@@ -75,6 +83,9 @@ bool profile_layernorm_impl(int do_verification,
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification, ...@@ -86,8 +97,8 @@ bool profile_layernorm_impl(int do_verification,
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification, ...@@ -105,40 +116,74 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType, using ReferenceInstance =
GammaDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
BetaDataType, GammaDataType,
YDataType, BetaDataType,
ComputeDataType, YDataType,
PassThrough, SaveMeanInvStdDataType,
Rank, ComputeDataType,
NumReduceDim>; PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, length, reduce_dim, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
length,
reduce_dim,
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
int num_kernel = 0; int num_kernel = 0;
auto f_get_argument = [&](auto& inst_ptr) {
if constexpr(SaveMeanInvStd)
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
PassThrough{});
else
return inst_ptr->MakeArgumentPointer(length,
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
strideSaveMeanInvStd,
strideSaveMeanInvStd,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
};
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
auto argument_ptr = inst_ptr->MakeArgumentPointer(length, auto argument_ptr = f_get_argument(inst_ptr);
strideXY,
strideGammaBeta,
strideGammaBeta,
strideXY,
reduce_dim,
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
nullptr,
nullptr,
PassThrough{});
if(inst_ptr->IsSupportedArgument(argument_ptr.get())) if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification, ...@@ -168,6 +213,10 @@ bool profile_layernorm_impl(int do_verification,
beta.mDesc.GetElementSize() * sizeof(BetaDataType) + beta.mDesc.GetElementSize() * sizeof(BetaDataType) +
y.mDesc.GetElementSize() * sizeof(YDataType); y.mDesc.GetElementSize() * sizeof(YDataType);
if constexpr(SaveMeanInvStd)
num_bytes += save_mean.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType) +
save_inv_std.mDesc.GetElementSpaceSize() * sizeof(SaveMeanInvStdDataType);
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel) if(time_kernel)
...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification, ...@@ -184,10 +233,23 @@ bool profile_layernorm_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if constexpr(SaveMeanInvStd)
{
save_mean_dev.FromDevice(save_mean.mData.data());
pass &= ck::utils::check_err(
save_mean.mData, host_save_mean.mData, "Error: Incorrect results", 1e-3, 1e-3);
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(save_inv_std.mData,
host_save_inv_std.mData,
"Error: Incorrect results",
1e-3,
1e-3);
}
if(do_log) if(do_log)
{ {
LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "x : ", x.mData, ",") << std::endl;
......
...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES ...@@ -25,8 +25,6 @@ set(PROFILER_SOURCES
profile_batchnorm_fwd.cpp profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp profile_batchnorm_infer.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_grouped_conv_bwd_data.cpp profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp profile_conv_tensor_rearrange.cpp
) )
...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -46,6 +44,11 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif() endif()
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler) set(PROFILER_EXECUTABLE ckProfiler)
add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan ...@@ -76,8 +79,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d ...@@ -85,9 +86,18 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
if(DL_KERNELS) if(DL_KERNELS)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
endif() endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
......
...@@ -20,10 +20,11 @@ enum struct ConvLayout ...@@ -20,10 +20,11 @@ enum struct ConvLayout
enum struct ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_F32_BF16, // 2 BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8 // 3 F16_F16_F16_BF8_F8, // 3
I8_I8_I8 // 4
}; };
#define OP_NAME "grouped_conv_bwd_weight" #define OP_NAME "grouped_conv_bwd_weight"
...@@ -35,7 +36,8 @@ static void print_helper_msg() ...@@ -35,7 +36,8 @@ static void print_helper_msg()
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight fp32, Output bf16\n" << " 2: Input bf16, Weight fp32, Output bf16\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n" << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
<< " 4: Input int8, Weight int8, Output int8)\n"
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]\n" "N, K, Ho, Wo]\n"
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
...@@ -84,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -84,12 +86,8 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#ifdef CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using namespace ck::tensor_layout::convolution; using namespace ck::tensor_layout::convolution;
...@@ -139,83 +137,93 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -139,83 +137,93 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(data_type == ConvDataType::F32_F32_F32)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
} }
else if(data_type == ConvDataType::F16_F16_F16) if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
else if(data_type == ConvDataType::BF16_F32_BF16) if(data_type == ConvDataType::BF16_F32_BF16)
{ {
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
else if(data_type == ConvDataType::F16_F16_F16_BF8_F8) if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
} }
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[]) ...@@ -93,12 +93,12 @@ int profile_groupnorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Float) if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32>( ck::profiler::profile_groupnorm_impl<F32, F32, F32, F32, F32, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Half) else if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16>( ck::profiler::profile_groupnorm_impl<F16, F16, F16, F32, F16, F32, false>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......
...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[]) ...@@ -82,12 +82,12 @@ int profile_layernorm(int argc, char* argv[])
if(data_type == ck::DataTypeEnum::Half) if(data_type == ck::DataTypeEnum::Half)
{ {
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, rank>( ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else if(data_type == ck::DataTypeEnum::Float) else if(data_type == ck::DataTypeEnum::Float)
{ {
ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, rank>( ck::profiler::profile_layernorm_impl<F32, F32, F32, F32, F32, F32, false, rank>(
do_verification, init_method, do_log, time_kernel, length); do_verification, init_method, do_log, time_kernel, length);
} }
else else
......
...@@ -11,40 +11,40 @@ function(add_test_executable TEST_NAME) ...@@ -11,40 +11,40 @@ function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}") message("adding test ${TEST_NAME}")
set(result 1) set(result 1)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16") if(type MATCHES "fp16")
set(type1 "_f16") set(type1 "_f16")
elseif(type MATCHES "fp32") elseif(type MATCHES "fp32")
set(type1 "_f32") set(type1 "_f32")
elseif(type MATCHES "fp8") elseif(type MATCHES "fp8")
set(type1 "_f8") set(type1 "_f8")
elseif(type MATCHES "bf16") elseif(type MATCHES "bf16")
set(type1 "_b16") set(type1 "_b16")
elseif(type MATCHES "fp64") elseif(type MATCHES "fp64")
set(type1 "_f64") set(type1 "_f64")
elseif(type MATCHES "int8") elseif(type MATCHES "int8")
set(type1 "_i8") set(type1 "_i8")
endif() endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list #if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0) set(test 0)
break() break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1)) NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal #if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1) set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
if(test EQUAL 1) endif()
message("removing test ${source} ") foreach(source IN LISTS ARGN)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ") message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
...@@ -70,38 +70,38 @@ function(add_gtest_executable TEST_NAME) ...@@ -70,38 +70,38 @@ function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}") message("adding gtest ${TEST_NAME}")
set(result 1) set(result 1)
if(DEFINED DTYPES) if(DEFINED DTYPES)
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
set(test 0) set(test 0)
foreach(type IN LISTS DTYPES) foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16") if(type MATCHES "fp16")
set(type1 "_f16") set(type1 "_f16")
elseif(type MATCHES "fp32") elseif(type MATCHES "fp32")
set(type1 "_f32") set(type1 "_f32")
elseif(type MATCHES "fp8") elseif(type MATCHES "fp8")
set(type1 "_f8") set(type1 "_f8")
elseif(type MATCHES "bf16") elseif(type MATCHES "bf16")
set(type1 "_b16") set(type1 "_b16")
elseif(type MATCHES "fp64") elseif(type MATCHES "fp64")
set(type1 "_f64") set(type1 "_f64")
elseif(type MATCHES "int8") elseif(type MATCHES "int8")
set(type1 "_i8") set(type1 "_i8")
endif() endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list #if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0) set(test 0)
break() break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1)) NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal #if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1) set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif() endif()
foreach(source IN LISTS ARGN) foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
......
add_gtest_executable(test_contraction test_contraction.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) add_gtest_executable(test_contraction test_contraction.cpp)
set(target 1) target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
endif() add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
set(target 1)
endif()
endif()
endforeach() endforeach()
add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp) add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp)
target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance) target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance)
add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp) add_gtest_executable(test_conv_tensor_rearrange_interface test_conv_tensor_rearrange_interface.cpp)
target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility) target_link_libraries(test_conv_tensor_rearrange_interface PRIVATE utility)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp) add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1)
endif()
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
\ No newline at end of file
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
...@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
{ {
// Odd K or C values are supported only by DL kernel (only applies to fp16) // Odd K or C values are supported only by DL and WMMA
// DL kernel currently supports only `split_k=1` // kernels (only applies to fp16)
// DL and WMMA kernels currently support only `split_k=1`
if constexpr(std::is_same_v<InDataType, ck::half_t>) if constexpr(std::is_same_v<InDataType, ck::half_t>)
{ {
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
...@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
} }
} }
const bool is_navi3x = ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102";
if(is_navi3x)
{
// on navi3x only support for 3d is implemented
if constexpr(NDimSpatial{} != 3)
{
return true;
}
// on navi3x only support for i8 and fp16 is implemented
if constexpr(!((std::is_same_v<InDataType, int8_t> &&
std::is_same_v<WeiDataType, int8_t> &&
std::is_same_v<OutDataType, int8_t>) ||
(std::is_same_v<InDataType, ck::half_t> &&
std::is_same_v<WeiDataType, ck::half_t> &&
std::is_same_v<OutDataType, ck::half_t>)))
{
return true;
}
// WMMA kernel is only supported for split_k=1
if(split_k != 1)
{
return true;
}
}
else
{
// support for i8 is only implemented on navi3x
if constexpr(std::is_same_v<InDataType, int8_t> &&
std::is_same_v<WeiDataType, int8_t> && std::is_same_v<OutDataType, int8_t>)
{
return true;
}
}
return false; return false;
} }
...@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types< ...@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>, std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>, std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>, std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include <gtest/gtest.h>
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using ConvolutionBackwardWeightSpecialization =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto Filter1x1Stride1Pad0 =
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
class TestGroupedConvndBwdWeight : public ::testing::Test
{
protected:
using OutLayout = std::tuple_element_t<0, Tuple>;
using WeiLayout = std::tuple_element_t<1, Tuple>;
using InLayout = std::tuple_element_t<2, Tuple>;
static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};
// clang-format off
using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle
//| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
//| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<NDimSpatial, InLayout, WeiLayout, OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
ck::utils::conv::ConvParam conv_param;
template <ck::index_t SplitK>
bool Run()
{
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
auto conv = GroupedConvBwdWeightDeviceInstance{};
auto argument = conv.MakeArgument(nullptr,
nullptr,
nullptr,
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{},
SplitK);
return conv.IsSupportedArgument(argument);
}
};
using namespace ck::tensor_layout::convolution;
using KernelTypes3d = ::testing::Types<std::tuple<GNDHWK, GKZYXC, GNDHWC, ck::Number<3>>,
std::tuple<NDHWGK, GKZYXC, NDHWGC, ck::Number<3>>>;
template <typename Tuple>
class TestGroupedConvndBwdWeightFilter1x13d
: public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
{
};
template <typename Tuple>
class TestGroupedConvndBwdWeightDefault3d
: public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d);
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
{
// Check filter 3x3x3 instead of 1x1x1
this->conv_param = {
3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
bool is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
// Check strides 2x2x2 instead of 1x1x1
this->conv_param = {
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
// Check with pad
this->conv_param = {
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
// Supported version
this->conv_param = {
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<1>();
EXPECT_TRUE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck)
{
// vector load for A
this->conv_param = {
3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
bool is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
// vector load for B, E, Ds
this->conv_param = {
3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<1>();
EXPECT_FALSE(is_supported);
}
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck)
{
// SplitK=1
this->conv_param = {
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
bool is_supported = this->template Run<1>();
EXPECT_TRUE(is_supported);
// SplitK=2
this->conv_param = {
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
is_supported = this->template Run<2>();
EXPECT_FALSE(is_supported);
}
...@@ -12,11 +12,12 @@ template <typename Tuple> ...@@ -12,11 +12,12 @@ template <typename Tuple>
class TestGroupnorm : public ::testing::Test class TestGroupnorm : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using ComputeDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
void Run() void Run()
{ {
...@@ -37,7 +38,9 @@ class TestGroupnorm : public ::testing::Test ...@@ -37,7 +38,9 @@ class TestGroupnorm : public ::testing::Test
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType>(true, 2, false, false, length); YDataType,
SaveMeanInvStdDataType,
true>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
} }
...@@ -45,7 +48,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -45,7 +48,7 @@ class TestGroupnorm : public ::testing::Test
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16>>; std::tuple<F16, F16, F16, F32, F16, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); } TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); }
...@@ -12,11 +12,12 @@ template <typename Tuple> ...@@ -12,11 +12,12 @@ template <typename Tuple>
class TestGroupnorm : public ::testing::Test class TestGroupnorm : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using ComputeDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
void Run() void Run()
{ {
...@@ -35,7 +36,9 @@ class TestGroupnorm : public ::testing::Test ...@@ -35,7 +36,9 @@ class TestGroupnorm : public ::testing::Test
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType>(true, 2, false, false, length); YDataType,
SaveMeanInvStdDataType,
true>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
} }
...@@ -43,7 +46,7 @@ class TestGroupnorm : public ::testing::Test ...@@ -43,7 +46,7 @@ class TestGroupnorm : public ::testing::Test
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>>; std::tuple<F32, F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST_SUITE(TestGroupnorm, KernelTypes);
TYPED_TEST(TestGroupnorm, Test_FP32) { this->Run(); } TYPED_TEST(TestGroupnorm, Test_FP32) { this->Run(); }
...@@ -12,11 +12,12 @@ template <typename Tuple> ...@@ -12,11 +12,12 @@ template <typename Tuple>
class TestLayernorm2d : public ::testing::Test class TestLayernorm2d : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using ComputeDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
void Run() void Run()
{ {
...@@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test ...@@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
true,
2>(true, 2, false, false, length); 2>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
...@@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F16, F16, F16, F32, F16>>; std::tuple<F16, F16, F16, F32, F16, F32>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); }
...@@ -12,11 +12,12 @@ template <typename Tuple> ...@@ -12,11 +12,12 @@ template <typename Tuple>
class TestLayernorm2d : public ::testing::Test class TestLayernorm2d : public ::testing::Test
{ {
protected: protected:
using XDataType = std::tuple_element_t<0, Tuple>; using XDataType = std::tuple_element_t<0, Tuple>;
using GammaDataType = std::tuple_element_t<1, Tuple>; using GammaDataType = std::tuple_element_t<1, Tuple>;
using BetaDataType = std::tuple_element_t<2, Tuple>; using BetaDataType = std::tuple_element_t<2, Tuple>;
using ComputeDataType = std::tuple_element_t<3, Tuple>; using ComputeDataType = std::tuple_element_t<3, Tuple>;
using YDataType = std::tuple_element_t<4, Tuple>; using YDataType = std::tuple_element_t<4, Tuple>;
using SaveMeanInvStdDataType = std::tuple_element_t<5, Tuple>;
void Run() void Run()
{ {
...@@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test ...@@ -31,6 +32,8 @@ class TestLayernorm2d : public ::testing::Test
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
true,
2>(true, 2, false, false, length); 2>(true, 2, false, false, length);
EXPECT_TRUE(success); EXPECT_TRUE(success);
} }
...@@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test ...@@ -39,7 +42,7 @@ class TestLayernorm2d : public ::testing::Test
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType>
std::tuple<F32, F32, F32, F32, F32>>; std::tuple<F32, F32, F32, F32, F32, F32>>;
TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes);
TYPED_TEST(TestLayernorm2d, Test_FP32) { this->Run(); } TYPED_TEST(TestLayernorm2d, Test_FP32) { this->Run(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment