"sims/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "3efd1b1532fb393d4ac84a2c87e172749b524b26"
Commit b2ba0a69 authored by Jing Zhang's avatar Jing Zhang
Browse files

Merge remote-tracking branch 'origin/develop' into grouped_gemm_dev_args

parents 48d356fc 49180fd6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f16_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv_bwd_data_xdl_f32_instances<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
set(CONV2D_PERLAYER_QUANT_SRC set(CONV2D_PERLAYER_QUANT_SRC
conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp
conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
...@@ -36,3 +37,4 @@ add_instance_library(device_quantization_instance ...@@ -36,3 +37,4 @@ add_instance_library(device_quantization_instance
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC} ${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
${GEMM_QUANT_SRC} ${GEMM_QUANT_SRC}
) )
endif()
\ No newline at end of file
...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in ...@@ -81,4 +81,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_in
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
...@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
const int BatchCount = std::stoi(argv[17]); const int BatchCount = std::stoi(argv[17]);
using F16 = ck::half_t; using F16 = ck::half_t;
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{ {
return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
} }
#ifdef __int8__
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
...@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[]) ...@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
{ {
return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
#endif
using NWC = ck::tensor_layout::convolution::NWC; using NWC = ck::tensor_layout::convolution::NWC;
using NHWC = ck::tensor_layout::convolution::NHWC; using NHWC = ck::tensor_layout::convolution::NHWC;
...@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
{ {
...@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
{ {
...@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{});
} }
#ifdef __int8__
else if(data_type == ConvDataType::INT8_INT8_INT8) else if(data_type == ConvDataType::INT8_INT8_INT8)
{ {
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{});
} }
#endif
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[]) ...@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]); const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]); const int StrideC = std::stoi(argv[13]);
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; #ifdef __bf16__
using BF16 = ck::bhalf_t;
#endif
#ifdef __int8__
using INT8 = int8_t; using INT8 = int8_t;
using INT32 = int32_t; using INT32 = int32_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
} }
#ifdef __bf16__
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
...@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
} }
#endif
#ifdef __int8__
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
...@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[]) ...@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[])
{ {
return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -77,15 +77,10 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
constexpr auto I2 = ck::Number<2>{}; constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp, auto profile = [&](auto num_dim_spatial_tmp,
auto out_layout, auto out_layout,
...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) ...@@ -116,36 +111,70 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
// GNHWC_GKYXC_GNHWK if(num_dim_spatial == 2)
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
// NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 3)
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
if(data_type == ConvDataType::F32_F32_F32) if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{});
}
} }
else if(data_type == ConvDataType::BF16_BF16_BF16) else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{ {
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{});
}
} }
} }
......
...@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>, ...@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
} // namespace } // namespace
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
#ifdef __fp16
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); } TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
#endif
#ifdef __int8__
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); } TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
#endif
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp32 gemm_fp32.cpp) add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE utility) target_link_libraries(test_gemm_fp32 PRIVATE utility)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp16 gemm_fp16.cpp) add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
target_link_libraries(test_gemm_fp16 PRIVATE utility) target_link_libraries(test_gemm_fp16 PRIVATE utility)
target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE utility)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp instance/gemm_f16_nt_instance.cpp
...@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC ...@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp) add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE utility)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
endif()
\ No newline at end of file
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
endif() endif()
\ No newline at end of file
...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test ...@@ -46,23 +46,36 @@ class TestGroupedConvndBwdData : public ::testing::Test
} }
}; };
using GNHWC = ck::tensor_layout::convolution::GNHWC; using namespace ck::tensor_layout::convolution;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GKYXC = ck::tensor_layout::convolution::GKYXC; using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
std::tuple<float, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
using GNHWK = ck::tensor_layout::convolution::GNHWK; using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
using NHWGK = ck::tensor_layout::convolution::NHWGK; std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
using KernelTypes = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>, template <typename Tuple>
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>, class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>, {
std::tuple<float, NHWGK, GKYXC, NHWGC>, };
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>; template <typename Tuple>
TYPED_TEST_SUITE(TestGroupedConvndBwdData, KernelTypes); class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple>
{
};
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d);
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdData, Test2D) TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
{ {
this->conv_params.clear(); this->conv_params.clear();
...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D) ...@@ -76,3 +89,15 @@ TYPED_TEST(TestGroupedConvndBwdData, Test2D)
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->template Run<2>(); this->template Run<2>();
} }
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
{
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->template Run<3>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment