Unverified Commit 146972f4 authored by Chao Liu, committed by GitHub

fix bug in gemm profiler (#344)

parent 75ab874e
@@ -26,6 +26,7 @@ static constexpr auto ConvSpec =
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+#if 1
 template <ck::index_t NDimSpatial,
           typename InLayout,
           typename WeiLayout,
@@ -78,6 +79,60 @@ using DeviceGroupedConvNDFwdInstance =
         1,
         S<1, 32, 1, 8>,
         8>;
+#else
+template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename BiasLayout,
+          typename OutLayout>
+using DeviceGroupedConvNDFwdInstance =
+    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
+        NDimSpatial,
+        InLayout,
+        WeiLayout,
+        ck::Tuple<BiasLayout>,
+        OutLayout,
+        InDataType,
+        WeiDataType,
+        AccDataType,
+        CShuffleDataType,
+        ck::Tuple<BiasDataType>,
+        OutDataType,
+        InElementOp,
+        WeiElementOp,
+        OutElementOp,
+        ConvSpec,    // ConvForwardSpecialization
+        GemmSpec,    // GemmSpecialization
+        1,           //
+        256,         // BlockSize
+        256,         // MPerBlock
+        16,          // NPerBlock
+        32,          // KPerBlock
+        8,           // AK1
+        8,           // BK1
+        16,          // MPerXdl
+        16,          // NPerXdl
+        4,           // MXdlPerWave
+        1,           // NXdlPerWave
+        S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
+        S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
+        2,           // ABlockTransferSrcVectorDim
+        8,           // ABlockTransferSrcScalarPerVector
+        8,           // ABlockTransferDstScalarPerVector_AK1
+        1,           // ABlockLdsExtraM
+        S<4, 16, 4>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
+        S<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,  // BBlockTransferSrcAccessOrder
+        2,           // BBlockTransferSrcVectorDim
+        2,           // BBlockTransferSrcScalarPerVector
+        2,           // BBlockTransferDstScalarPerVector_BK1
+        1,           // BBlockLdsExtraN
+        4,           // CShuffleMXdlPerWavePerShuffle
+        1,           // CShuffleNXdlPerWavePerShuffle
+        S<1, 256, 1, 1>,
+        1>;
+#endif
 int main(int argc, char* argv[])
 {
...
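
The added `#else` branch keeps an alternative, explicitly tuned `DeviceGroupedConvFwdMultipleD_Xdl_CShuffle` instantiation next to the default instance; flipping the guard to `#if 0` switches the example over without touching any other code. Below is a minimal, self-contained sketch of that compile-time selection pattern; `TunedKernel` and its parameters are illustrative placeholders, not CK types.

#include <cstdio>

// Illustrative stand-in, not a CK kernel: two alternative template
// instantiations hidden behind one alias, so the rest of the program only
// names the alias and tunings can be swapped in a single place.
template <int BlockSize, int MPerBlock, int NPerBlock>
struct TunedKernel
{
    static void describe()
    {
        std::printf("BlockSize=%d MPerBlock=%d NPerBlock=%d\n", BlockSize, MPerBlock, NPerBlock);
    }
};

#if 1
using SelectedKernel = TunedKernel<256, 128, 256>; // default tuning
#else
using SelectedKernel = TunedKernel<256, 256, 16>; // alternative tuning kept for experiments
#endif

int main() { SelectedKernel::describe(); }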
@@ -131,11 +131,11 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
                                                               PassThrough,
                                                               PassThrough>>>& instances);
 
-// grouped conv2d forward, NHWGC/KYXGC/NHWGK
-void add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
+void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               NHWGC,
-                                                              KYXGC,
+                                                              GKYXC,
                                                               Empty_Tuple,
                                                               NHWGK,
                                                               F16,
@@ -292,7 +292,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
-                          is_same_v<WeiLayout, KYXGC> && is_same_v<OutLayout, NHWGK>)
+                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
@@ -302,7 +302,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
-                add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(op_ptrs);
+                add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, ck::bhalf_t> &&
...
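
The weight-layout rename has to be applied consistently across the instance declarations, the factory's `if constexpr` dispatch above, and the CMake source list below: if the factory kept comparing against the old `KYXGC` tag while the instances were registered for `GKYXC`, that branch would never match and the factory would hand back an empty instance list with no compile-time diagnostic. A small, self-contained sketch of this failure mode follows; the tag structs and the `get_instances` helper are illustrative, not the CK factory API.

#include <string>
#include <type_traits>
#include <vector>

// Illustrative layout tags; in CK the real ones live in ck::tensor_layout::convolution.
struct GKYXC {};
struct KYXGC {};

template <typename WeiLayout>
std::vector<std::string> get_instances()
{
    std::vector<std::string> names;
    // If this check names the wrong tag, the branch is silently skipped and
    // callers receive an empty list instead of a compile-time error.
    if constexpr(std::is_same_v<WeiLayout, GKYXC>)
    {
        names.push_back("conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16");
    }
    return names;
}

int main()
{
    // Matching tag: one instance. Mismatched tag: empty vector, no diagnostic.
    return (get_instances<GKYXC>().size() == 1 && get_instances<KYXGC>().empty()) ? 0 : 1;
}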
@@ -5,8 +5,8 @@ set(DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp;
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp;
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp;
-    # NHWGC, KYXGC, NHWGK
-    device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instance.cpp;
+    # NHWGC, GKYXC, NHWGK
+    device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp;
 )
 
 add_library(device_grouped_conv2d_fwd_instance OBJECT ${DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE})
...
@@ -72,34 +72,34 @@ int profile_gemm(int argc, char* argv[])
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-    auto profile = [&](auto a_type,
+    auto profile = [&](auto a_layout,
+                       auto b_layout,
+                       auto c_layout,
+                       auto a_type,
                        auto b_type,
                        auto acc_type,
-                       auto c_type,
-                       auto a_layout,
-                       auto b_layout,
-                       auto c_layout) {
+                       auto c_type) {
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
+
         using ADataType = decltype(a_type);
         using BDataType = decltype(b_type);
         using AccDataType = decltype(acc_type);
         using CDataType = decltype(c_type);
 
-        using ALayout = decltype(a_layout);
-        using BLayout = decltype(b_layout);
-        using CLayout = decltype(c_layout);
-
         const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
         const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
         const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
 
         bool pass =
-            ck::profiler::profile_gemm_impl<ADataType,
+            ck::profiler::profile_gemm_impl<ALayout,
+                                            BLayout,
+                                            CLayout,
+                                            ADataType,
                                             BDataType,
                                             AccDataType,
-                                            CDataType,
-                                            ALayout,
-                                            BLayout,
-                                            CLayout>(do_verification,
+                                            CDataType>(do_verification,
                                                        init_method,
                                                        do_log,
                                                        time_kernel,
...
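
This hunk is the GEMM profiler bug named in the commit title: `profile_gemm_impl` expects its layout template parameters before its data-type parameters, but the lambda previously supplied the explicit template arguments data-types-first, so each argument bound to the wrong parameter. Reordering the lambda parameters and the explicit argument list keeps the call site aligned with the callee. A minimal, self-contained sketch of why explicit template-argument order matters follows; `report` and its tag types are illustrative, not the CK profiler API.

#include <cstdio>
#include <type_traits>

// Illustrative tag types, not CK's.
struct RowMajor {};
struct ColumnMajor {};

// The callee fixes the meaning of each position: layout first, data type second.
template <typename Layout, typename DataType>
void report()
{
    std::printf("row-major: %d, element is float: %d\n",
                static_cast<int>(std::is_same_v<Layout, RowMajor>),
                static_cast<int>(std::is_same_v<DataType, float>));
}

int main()
{
    report<RowMajor, float>(); // arguments bound as intended
    // report<float, RowMajor>(); // same types, wrong order: Layout becomes float and
    //                            // DataType becomes RowMajor, so every layout/type query
    //                            // inside the callee is answered against the wrong parameter.
}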
@@ -13,7 +13,7 @@ namespace {
 enum struct ConvLayout
 {
     GNHWC_GKYXC_GNHWK, // 0
-    NHWGC_KYXGC_NHWGK, // 1
+    NHWGC_GKYXC_NHWGK, // 1
 };
 
 enum struct ConvDataType
@@ -34,7 +34,7 @@ static void print_helper_msg()
              << "                 2: Input bf16, Weight bf16, Output bf16\n"
              << "                 3: Input int8, Weight int8, Output int8)\n"
              << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
-             << "                     1: Input[N, Hi, Wi, G, C], Weight[K, Y, X, G, C], Output[N, Ho, Wo, G, K])\n"
+             << "                     1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
              << "arg4: verification (0: no, 1: yes)\n"
              << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
              << "arg6: print tensor value (0: no; 1: yes)\n"
@@ -94,10 +94,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using NHWGC  = ck::tensor_layout::convolution::NHWGC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
 
-    using KXGC   = ck::tensor_layout::convolution::KXGC;
-    using KYXGC  = ck::tensor_layout::convolution::KYXGC;
-    using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
-
     using NWGK   = ck::tensor_layout::convolution::NWGK;
     using NHWGK  = ck::tensor_layout::convolution::NHWGK;
     using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
@@ -193,62 +189,62 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    // NHWGC_KYXGC_NHWGK
-    else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    // NHWGC_GKYXC_NHWGK
+    else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, F32{}, F32{}, F32{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, F16{}, F16{}, F16{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, F32{}, F32{}, F32{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, F16{}, F16{}, F16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
        }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, F32{}, F32{}, F32{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, F16{}, F16{}, F16{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
...
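
The `profile(...)` calls above pass default-constructed tag objects such as `NHWGC{}`, `GKYXC{}`, and `F16{}`; a generic callable then recovers the corresponding types with `decltype`, so one dispatch table covers every layout and data-type combination. A minimal, self-contained sketch of that value-tag idiom follows; the tag structs and the `profile` lambda are illustrative, not the profiler's actual definitions.

#include <cstdio>
#include <type_traits>

// Illustrative empty tag types standing in for layout descriptors.
struct NHWGC {};
struct GKYXC {};
struct NHWGK {};

int main()
{
    // The lambda receives values only to name types; decltype recovers them,
    // so the call site can select layouts and data types without spelling out
    // explicit template arguments.
    auto profile = [](auto in_layout, auto wei_layout, auto out_layout, auto element) {
        using InLayout  = decltype(in_layout);
        using WeiLayout = decltype(wei_layout);
        using OutLayout = decltype(out_layout);
        using DataType  = decltype(element);

        std::printf("NHWGC/GKYXC/NHWGK selected: %d, fp32 element: %d\n",
                    static_cast<int>(std::is_same_v<InLayout, NHWGC> &&
                                     std::is_same_v<WeiLayout, GKYXC> &&
                                     std::is_same_v<OutLayout, NHWGK>),
                    static_cast<int>(std::is_same_v<DataType, float>));
        return 0;
    };

    return profile(NHWGC{}, GKYXC{}, NHWGK{}, float{});
}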