Commit be58e518 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 94642acf afbf6350
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_large_tensor_f32_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
set(FMHA_CPP_FOLDER ${CMAKE_CURRENT_BINARY_DIR})
set(FMHA_SRC_FOLDER ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/)
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
# python stuff
find_package(PythonInterp 3 REQUIRED)
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
rocm_install(FILES
"${FMHA_SRC_FOLDER}/fmha_fwd.hpp"
"${FMHA_SRC_FOLDER}/bias.hpp"
"${FMHA_SRC_FOLDER}/mask.hpp"
DESTINATION include/ck_tile/ops
)
# header for building lib
file(COPY ${FMHA_SRC_FOLDER}/fmha_fwd.hpp DESTINATION ${FMHA_CPP_FOLDER})
file(COPY ${FMHA_SRC_FOLDER}/bias.hpp DESTINATION ${FMHA_CPP_FOLDER})
file(COPY ${FMHA_SRC_FOLDER}/mask.hpp DESTINATION ${FMHA_CPP_FOLDER})
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
)
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS)
# actually generate the cpp files
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
--output_dir ${FMHA_CPP_FOLDER}
COMMENT "Generating mha kernel (cpp) files now ..."
VERBATIM
)
# This is done to remove path info and just
# have filename. Since, it was cauing the cmake
# to throw "File name too long"
set(device_files)
foreach(filepath IN LISTS FMHA_FWD_GEN_BLOBS)
get_filename_component(filename ${filepath} NAME)
# Append the filename to the device_files list
list(APPEND device_files ${filename})
endforeach()
add_custom_target(generate_cpp_files DEPENDS ${FMHA_FWD_GEN_BLOBS})
add_instance_library(device_mha_instance ${device_files})
if (TARGET device_mha_instance)
add_dependencies(device_mha_instance generate_cpp_files)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host_utility/io.hpp"
......@@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim,
const std::vector<ck::index_t>& dilations,
const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads)
: num_dim_spatial_(static_cast<ck::long_index_t>(n_dim)),
G_(static_cast<ck::long_index_t>(group_count)),
N_(static_cast<ck::long_index_t>(n_batch)),
K_(static_cast<ck::long_index_t>(n_out_channels)),
C_(static_cast<ck::long_index_t>(n_in_channels)),
filter_spatial_lengths_(num_dim_spatial_),
input_spatial_lengths_(num_dim_spatial_),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(num_dim_spatial_),
conv_filter_dilations_(num_dim_spatial_),
input_left_pads_(num_dim_spatial_),
input_right_pads_(num_dim_spatial_)
{
if(static_cast<ck::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(
std::runtime_error("ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck::index_t i = 0; i < num_dim_spatial_; ++i)
{
filter_spatial_lengths_[i] = static_cast<ck::long_index_t>(filters_len[i]);
input_spatial_lengths_[i] = static_cast<ck::long_index_t>(input_len[i]);
conv_filter_strides_[i] = static_cast<ck::long_index_t>(strides[i]);
conv_filter_dilations_[i] = static_cast<ck::long_index_t>(dilations[i]);
input_left_pads_[i] = static_cast<ck::long_index_t>(left_pads[i]);
input_right_pads_[i] = static_cast<ck::long_index_t>(right_pads[i]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ConvParam::ConvParam(ck::long_index_t n_dim,
ck::long_index_t group_count,
ck::long_index_t n_batch,
ck::long_index_t n_out_channels,
ck::long_index_t n_in_channels,
const std::vector<ck::long_index_t>& filters_len,
const std::vector<ck::long_index_t>& input_len,
const std::vector<ck::long_index_t>& strides,
const std::vector<ck::long_index_t>& dilations,
const std::vector<ck::long_index_t>& left_pads,
const std::vector<ck::long_index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
......@@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim,
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
......@@ -63,7 +121,7 @@ ConvParam::ConvParam()
{
}
std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
std::vector<ck::long_index_t> ConvParam::GetOutputSpatialLengths() const
{
return output_spatial_lengths_;
}
......@@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg()
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck::index_t G = std::stoi(argv[arg_idx++]);
const ck::index_t N = std::stoi(argv[arg_idx++]);
const ck::index_t K = std::stoi(argv[arg_idx++]);
const ck::index_t C = std::stoi(argv[arg_idx++]);
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
const ck::long_index_t G = std::stol(argv[arg_idx++]);
const ck::long_index_t N = std::stol(argv[arg_idx++]);
const ck::long_index_t K = std::stol(argv[arg_idx++]);
const ck::long_index_t C = std::stol(argv[arg_idx++]);
std::vector<ck::long_index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck::long_index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck::long_index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck::long_index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck::long_index_t> input_left_pads(num_dim_spatial);
std::vector<ck::long_index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stoi(argv[arg_idx++]);
input_left_pads[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stoi(argv[arg_idx++]);
input_right_pads[i] = std::stol(argv[arg_idx++]);
}
return ck::utils::conv::ConvParam{num_dim_spatial,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -82,6 +82,29 @@ bool profile_conv_bwd_data_impl(int do_verification,
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> output(out_g_n_k_wos_desc);
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
std::cout << "input: " << input_host_result.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
......@@ -161,16 +184,16 @@ bool profile_conv_bwd_data_impl(int do_verification,
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.output_spatial_lengths_,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
static_cast<ck::index_t>(conv_param.N_),
static_cast<ck::index_t>(conv_param.K_),
static_cast<ck::index_t>(conv_param.C_),
input_spatial_lengths_i32,
filter_spatial_lengths_i32,
output_spatial_lengths_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
in_element_op,
wei_element_op,
out_element_op);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -60,6 +60,29 @@ bool profile_conv_fwd_impl(int do_verification,
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << host_output.mDesc << std::endl;
......@@ -143,16 +166,16 @@ bool profile_conv_fwd_impl(int do_verification,
op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
static_cast<ck::index_t>(conv_param.N_),
static_cast<ck::index_t>(conv_param.K_),
static_cast<ck::index_t>(conv_param.C_),
input_spatial_lengths_i32,
filter_spatial_lengths_i32,
output_spatial_lengths_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
in_element_op,
wei_element_op,
out_element_op);
......
......@@ -33,7 +33,8 @@ template <ck::index_t NDimSpatial,
typename WeiDataType,
typename OutDataType,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
typename BComputeType = AComputeType,
typename IndexType = ck::index_t>
bool profile_grouped_conv_fwd_impl(int do_verification,
int init_method,
bool do_log,
......@@ -57,16 +58,16 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<IndexType, NDimSpatial> conv_filter_strides{};
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
std::array<IndexType, NDimSpatial> input_left_pads{};
std::array<IndexType, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
......
......@@ -29,6 +29,12 @@ enum struct ConvDataType
BF8_F8_F8, // 7
};
enum struct IndexType
{
INDEX_T, // 0
LONG_INDEX_T, // 1
};
#define OP_NAME "grouped_conv_fwd"
#define OP_DESC "Grouped Convolution Forward"
......@@ -45,12 +51,13 @@ static void print_helper_msg()
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8\n"
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
// clang-format on
}
......@@ -60,7 +67,7 @@ static void print_helper_msg()
int profile_grouped_conv_fwd(int argc, char* argv[])
{
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
if(argc < 10)
{
print_helper_msg();
return 1;
......@@ -68,20 +75,21 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int num_dim_spatial = std::stoi(argv[8]);
const auto index_type = static_cast<IndexType>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const int init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const int num_dim_spatial = std::stoi(argv[9]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
{
print_helper_msg();
return 1;
}
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
using F32 = float;
using F16 = ck::half_t;
......@@ -138,18 +146,43 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using AComputeType = decltype(a_compute_type);
using BComputeType = decltype(b_compute_type);
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
AComputeType,
BComputeType>(
do_verification, init_method, do_log, time_kernel, params);
if(index_type == IndexType::INDEX_T)
{
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
AComputeType,
BComputeType,
ck::index_t>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
}
else if(index_type == IndexType::LONG_INDEX_T)
{
bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
AComputeType,
BComputeType,
ck::long_index_t>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
return pass ? 0 : 1;
}
else
{
std::cout << "this indexing data type is not implemented" << std::endl;
return 1;
}
};
// GNHWC_GKYXC_GNHWK
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment