"...resnet50_tensorflow.git" did not exist on "9c314a0347e598b41d466f85f1c42e2682396df2"
Unverified commit 146972f4, authored by Chao Liu and committed by GitHub

fix bug in gemm profiler (#344)

parent 75ab874e
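The bug in question: the gemm profiler passed template arguments to profile_gemm_impl with data types first, while the implementation expects layouts first; the diff below reorders both the lambda parameters and the template-argument list, and also corrects the grouped-conv weight layout naming from KYXGC to GKYXC. A reduced sketch of the ordering pitfall, with simplified hypothetical signatures rather than CK's full API:

```cpp
#include <iostream>
#include <type_traits>

struct Row {};
struct Col {};

// Simplified stand-in for profile_gemm_impl: layout parameters come first,
// then the element type, which is the ordering the profiler must respect.
template <typename ALayout, typename ADataType>
void profile_gemm_impl()
{
    std::cout << (std::is_same_v<ALayout, Row> ? "row-major" : "col-major")
              << ", element size " << sizeof(ADataType) << " bytes\n";
}

int main()
{
    profile_gemm_impl<Row, float>(); // correct: layout first, then data type
    // profile_gemm_impl<float, Row>(); // the pre-fix ordering: float would be
    // taken as the layout and Row as the data type, silently changing meaning
}
```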
......@@ -26,6 +26,7 @@ static constexpr auto ConvSpec =
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+#if 1
 template <ck::index_t NDimSpatial,
           typename InLayout,
           typename WeiLayout,
......@@ -78,6 +79,60 @@ using DeviceGroupedConvNDFwdInstance =
         1,
         S<1, 32, 1, 8>,
         8>;
+#else
+template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename BiasLayout,
+          typename OutLayout>
+using DeviceGroupedConvNDFwdInstance =
+    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
+        NDimSpatial,
+        InLayout,
+        WeiLayout,
+        ck::Tuple<BiasLayout>,
+        OutLayout,
+        InDataType,
+        WeiDataType,
+        AccDataType,
+        CShuffleDataType,
+        ck::Tuple<BiasDataType>,
+        OutDataType,
+        InElementOp,
+        WeiElementOp,
+        OutElementOp,
+        ConvSpec,        // ConvForwardSpecialization
+        GemmSpec,        // GemmSpecialization
+        1,               //
+        256,             // BlockSize
+        256,             // MPerBlock
+        16,              // NPerBlock
+        32,              // KPerBlock
+        8,               // AK1
+        8,               // BK1
+        16,              // MPerXdl
+        16,              // NPerXdl
+        4,               // MXdlPerWave
+        1,               // NXdlPerWave
+        S<4, 64, 1>,     // ABlockTransferThreadClusterLengths_AK0_M_AK1
+        S<1, 0, 2>,      // ABlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,      // ABlockTransferSrcAccessOrder
+        2,               // ABlockTransferSrcVectorDim
+        8,               // ABlockTransferSrcScalarPerVector
+        8,               // ABlockTransferDstScalarPerVector_AK1
+        1,               // ABlockLdsExtraM
+        S<4, 16, 4>,     // BBlockTransferThreadClusterLengths_BK0_N_BK1
+        S<1, 0, 2>,      // BBlockTransferThreadClusterArrangeOrder
+        S<1, 0, 2>,      // BBlockTransferSrcAccessOrder
+        2,               // BBlockTransferSrcVectorDim
+        2,               // BBlockTransferSrcScalarPerVector
+        2,               // BBlockTransferDstScalarPerVector_BK1
+        1,               // BBlockLdsExtraN
+        4,               // CShuffleMXdlPerWavePerShuffle
+        1,               // CShuffleNXdlPerWavePerShuffle
+        S<1, 256, 1, 1>,
+        1>;
+#endif
 int main(int argc, char* argv[])
 {
......
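A side note on reading the instance definitions above: parameters written S<...> are compile-time integer sequences (in CK, S is typically an alias for ck::Sequence) describing things like thread-cluster extents. A minimal stand-in, just to show what such a tag carries:

```cpp
#include <cstddef>
#include <iostream>

// Variadic tag type that carries integers purely at compile time.
template <int... Is>
struct Sequence
{
    static constexpr std::size_t size() { return sizeof...(Is); }
};

template <int... Is>
using S = Sequence<Is...>;

int main()
{
    using ClusterLengths = S<4, 64, 1>; // e.g. an AK0 x M x AK1 thread cluster
    std::cout << ClusterLengths::size() << " dimensions\n"; // prints "3 dimensions"
}
```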
......@@ -131,11 +131,11 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
     PassThrough,
     PassThrough>>>& instances);
-// grouped conv2d forward, NHWGC/KYXGC/NHWGK
-void add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(
+// grouped conv2d forward, NHWGC/GKYXC/NHWGK
+void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
                                                               NHWGC,
-                                                              KYXGC,
+                                                              GKYXC,
                                                               Empty_Tuple,
                                                               NHWGK,
                                                               F16,
......@@ -292,7 +292,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         }
     }
     else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
-                      is_same_v<WeiLayout, KYXGC> && is_same_v<OutLayout, NHWGK>)
+                      is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
     {
         if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                      is_same_v<OutDataType, float>)
......@@ -302,7 +302,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                           is_same_v<OutDataType, half_t>)
         {
-            add_device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instances(op_ptrs);
+            add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                           is_same_v<WeiDataType, ck::bhalf_t> &&
......
......@@ -5,8 +5,8 @@ set(DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp;
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp;
     device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp;
-    # NHWGC, KYXGC, NHWGK
-    device_grouped_conv2d_fwd_xdl_nhwgc_kyxgc_nhwgk_f16_instance.cpp;
+    # NHWGC, GKYXC, NHWGK
+    device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp;
 )
 add_library(device_grouped_conv2d_fwd_instance OBJECT ${DEVICE_GROUPED_CONV2D_FWD_INSTANCE_SOURCE})
......
......@@ -72,43 +72,43 @@ int profile_gemm(int argc, char* argv[])
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
-    auto profile = [&](auto a_type,
+    auto profile = [&](auto a_layout,
+                       auto b_layout,
+                       auto c_layout,
+                       auto a_type,
                        auto b_type,
                        auto acc_type,
-                       auto c_type,
-                       auto a_layout,
-                       auto b_layout,
-                       auto c_layout) {
+                       auto c_type) {
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
         using ADataType = decltype(a_type);
         using BDataType = decltype(b_type);
         using AccDataType = decltype(acc_type);
         using CDataType = decltype(c_type);
-        using ALayout = decltype(a_layout);
-        using BLayout = decltype(b_layout);
-        using CLayout = decltype(c_layout);
         const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
         const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
         const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
         bool pass =
-            ck::profiler::profile_gemm_impl<ADataType,
+            ck::profiler::profile_gemm_impl<ALayout,
+                                            BLayout,
+                                            CLayout,
+                                            ADataType,
                                             BDataType,
                                             AccDataType,
-                                            CDataType,
-                                            ALayout,
-                                            BLayout,
-                                            CLayout>(do_verification,
-                                                     init_method,
-                                                     do_log,
-                                                     time_kernel,
-                                                     M,
-                                                     N,
-                                                     K,
-                                                     (StrideA < 0) ? DefaultStrideA : StrideA,
-                                                     (StrideB < 0) ? DefaultStrideB : StrideB,
-                                                     (StrideC < 0) ? DefaultStrideC : StrideC);
+                                            CDataType>(do_verification,
+                                                       init_method,
+                                                       do_log,
+                                                       time_kernel,
+                                                       M,
+                                                       N,
+                                                       K,
+                                                       (StrideA < 0) ? DefaultStrideA : StrideA,
+                                                       (StrideB < 0) ? DefaultStrideB : StrideB,
+                                                       (StrideC < 0) ? DefaultStrideC : StrideC);
         return pass ? 0 : 1;
     };
......
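Beyond the reordering, the lambda above also derives default strides when the caller passes a negative value: a packed row-major matrix's leading dimension is its column count, a column-major one's is its row count. A small self-contained sketch of that rule (the helper name is mine, not CK's):

```cpp
#include <iostream>
#include <type_traits>

struct Row {};
struct Col {};

// Default leading dimension of a packed matrix: row-major storage strides by
// the column count, column-major storage by the row count.
template <typename Layout>
int default_stride(int rows, int cols)
{
    return std::is_same_v<Layout, Row> ? cols : rows;
}

int main()
{
    const int M = 64, K = 32;
    std::cout << default_stride<Row>(M, K) << '\n'; // 32, i.e. StrideA = K
    std::cout << default_stride<Col>(M, K) << '\n'; // 64, i.e. StrideA = M
}
```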
......@@ -13,7 +13,7 @@ namespace {
 enum struct ConvLayout
 {
     GNHWC_GKYXC_GNHWK, // 0
-    NHWGC_KYXGC_NHWGK, // 1
+    NHWGC_GKYXC_NHWGK, // 1
 };
 enum struct ConvDataType
......@@ -34,7 +34,7 @@ static void print_helper_msg()
            << " 2: Input bf16, Weight bf16, Output bf16\n"
            << " 3: Input int8, Weight int8, Output int8)\n"
            << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
-           << "                     1: Input[N, Hi, Wi, G, C], Weight[K, Y, X, G, C], Output[N, Ho, Wo, G, K])\n"
+           << "                     1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
            << "arg4: verification (0: no, 1: yes)\n"
            << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
            << "arg6: print tensor value (0: no; 1: yes)\n"
......@@ -94,10 +94,6 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using NHWGC  = ck::tensor_layout::convolution::NHWGC;
     using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
-    using KXGC   = ck::tensor_layout::convolution::KXGC;
-    using KYXGC  = ck::tensor_layout::convolution::KYXGC;
-    using KZYXGC = ck::tensor_layout::convolution::KZYXGC;
     using NWGK   = ck::tensor_layout::convolution::NWGK;
     using NHWGK  = ck::tensor_layout::convolution::NHWGK;
     using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
......@@ -193,62 +189,62 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
             return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    // NHWGC_KYXGC_NHWGK
-    else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    // NHWGC_GKYXC_NHWGK
+    else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, F32{}, F32{}, F32{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, F16{}, F16{}, F16{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I1, NWGC{}, KXGC{}, NWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, F32{}, F32{}, F32{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
        {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, F16{}, F16{}, F16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I2, NHWGC{}, KYXGC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
-    else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_KYXGC_NHWGK)
+    else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, F32{}, F32{}, F32{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{});
         }
         else if(data_type == ConvDataType::F16_F16_F16)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, F16{}, F16{}, F16{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{});
         }
         else if(data_type == ConvDataType::BF16_BF16_BF16)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{});
         }
         else if(data_type == ConvDataType::INT8_INT8_INT8)
         {
-            return profile(I3, NDHWGC{}, KZYXGC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
+            return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{});
         }
     }
......