Unverified commit a6390bbe, authored by Illia Silin, committed by GitHub.

Merge branch 'develop' into jizhan/enable_bf16_atomic_add

parents 2168c0d9 901e5f15
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_f16_instances<2,
                                                               NHWGC,
                                                               GKYXC,
                                                               Empty_Tuple,
                                                               NHWGK,
                                                               ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NHWGC,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NHWGK,
                                                                F32,
                                                                F32,
                                                                Empty_Tuple,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_f32_instances<2,
                                                               NHWGC,
                                                               GKYXC,
                                                               Empty_Tuple,
                                                               NHWGK,
                                                               ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD
     xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
     xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
+    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
+    xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
     xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
     xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
     xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                BF16,
                                                                BF16,
                                                                Empty_Tuple,
                                                                BF16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_f16_instances<3,
                                                               NDHWGC,
                                                               GKZYXC,
                                                               Empty_Tuple,
                                                               NDHWGK,
                                                               ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
                                                                NDHWGC,
                                                                GKZYXC,
                                                                Empty_Tuple,
                                                                NDHWGK,
                                                                F32,
                                                                F32,
                                                                Empty_Tuple,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_large_tensor_f32_instances<3,
                                                               NDHWGC,
                                                               GKZYXC,
                                                               Empty_Tuple,
                                                               NDHWGK,
                                                               ConvFwdDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
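For orientation, here is a minimal sketch (not part of the commit) of how a client might collect the instances registered above; it assumes the usual CK instance headers are included and reuses the aliases from these files (NDHWGC, GKZYXC, NDHWGK, F32, Empty_Tuple, PassThrough):

// Gather the f32 3D large-tensor instances and print what was registered.
std::vector<std::unique_ptr<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<3,
                                                                  NDHWGC,
                                                                  GKZYXC,
                                                                  Empty_Tuple,
                                                                  NDHWGK,
                                                                  F32,
                                                                  F32,
                                                                  Empty_Tuple,
                                                                  F32,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>
    instances;
ck::tensor_operation::device::instance::
    add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(instances);
for(const auto& op : instances)
    std::cout << op->GetTypeString() << std::endl; // each op also exposes argument creation and support queries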
set(FMHA_CPP_FOLDER ${CMAKE_CURRENT_BINARY_DIR})
set(FMHA_SRC_FOLDER ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/)
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
# python stuff
find_package(PythonInterp 3 REQUIRED)
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
rocm_install(FILES
    "${FMHA_SRC_FOLDER}/fmha_fwd.hpp"
    "${FMHA_SRC_FOLDER}/bias.hpp"
    "${FMHA_SRC_FOLDER}/mask.hpp"
    DESTINATION include/ck_tile/ops
)
# headers needed for building the lib
file(COPY ${FMHA_SRC_FOLDER}/fmha_fwd.hpp DESTINATION ${FMHA_CPP_FOLDER})
file(COPY ${FMHA_SRC_FOLDER}/bias.hpp DESTINATION ${FMHA_CPP_FOLDER})
file(COPY ${FMHA_SRC_FOLDER}/mask.hpp DESTINATION ${FMHA_CPP_FOLDER})
# generate the list of kernels at configure time, without actually emitting the files yet
execute_process(
    COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
    --list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
)
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS)
# actually generate the cpp files at build time
add_custom_command(
    OUTPUT ${FMHA_FWD_GEN_BLOBS}
    COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
    --output_dir ${FMHA_CPP_FOLDER}
    COMMENT "Generating mha kernel (cpp) files now ..."
    VERBATIM
)
# Strip the path info and keep just the filename; the full paths were
# causing CMake to throw "File name too long".
set(device_files)
foreach(filepath IN LISTS FMHA_FWD_GEN_BLOBS)
    get_filename_component(filename ${filepath} NAME)
    # Append the filename to the device_files list
    list(APPEND device_files ${filename})
endforeach()
add_custom_target(generate_cpp_files DEPENDS ${FMHA_FWD_GEN_BLOBS})
add_instance_library(device_mha_instance ${device_files})
if (TARGET device_mha_instance)
    add_dependencies(device_mha_instance generate_cpp_files)
endif()
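The two-stage flow above (list the kernels at configure time, emit them at build time) can be reproduced by hand; a sketch, assuming the repository root as working directory and an existing build/ directory (both flags are the ones the CMake passes to generate.py):

python3 example/ck_tile/01_fmha/generate.py --list_blobs build/blob_list.txt   # configure stage: write the kernel file list only
python3 example/ck_tile/01_fmha/generate.py --output_dir build/                # build stage: actually emit the generated .cpp files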
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck/host_utility/io.hpp"
@@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim,
                      const std::vector<ck::index_t>& dilations,
                      const std::vector<ck::index_t>& left_pads,
                      const std::vector<ck::index_t>& right_pads)
+    : num_dim_spatial_(static_cast<ck::long_index_t>(n_dim)),
+      G_(static_cast<ck::long_index_t>(group_count)),
+      N_(static_cast<ck::long_index_t>(n_batch)),
+      K_(static_cast<ck::long_index_t>(n_out_channels)),
+      C_(static_cast<ck::long_index_t>(n_in_channels)),
+      filter_spatial_lengths_(num_dim_spatial_),
+      input_spatial_lengths_(num_dim_spatial_),
+      output_spatial_lengths_(num_dim_spatial_),
+      conv_filter_strides_(num_dim_spatial_),
+      conv_filter_dilations_(num_dim_spatial_),
+      input_left_pads_(num_dim_spatial_),
+      input_right_pads_(num_dim_spatial_)
+{
+    if(static_cast<ck::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
+       static_cast<ck::index_t>(input_right_pads_.size()) != num_dim_spatial_)
+    {
+        throw(
+            std::runtime_error("ConvParam::ConvParam: "
+                               "parameter size is different from number of declared dimensions!"));
+    }
+    for(ck::index_t i = 0; i < num_dim_spatial_; ++i)
+    {
+        filter_spatial_lengths_[i] = static_cast<ck::long_index_t>(filters_len[i]);
+        input_spatial_lengths_[i]  = static_cast<ck::long_index_t>(input_len[i]);
+        conv_filter_strides_[i]    = static_cast<ck::long_index_t>(strides[i]);
+        conv_filter_dilations_[i]  = static_cast<ck::long_index_t>(dilations[i]);
+        input_left_pads_[i]        = static_cast<ck::long_index_t>(left_pads[i]);
+        input_right_pads_[i]       = static_cast<ck::long_index_t>(right_pads[i]);
+        // XEff = (X - 1) * conv_dilation_w + 1;
+        // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
+        const ck::long_index_t x_eff =
+            (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
+        output_spatial_lengths_[i] =
+            (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
+                conv_filter_strides_[i] +
+            1;
+    }
+}
+ConvParam::ConvParam(ck::long_index_t n_dim,
+                     ck::long_index_t group_count,
+                     ck::long_index_t n_batch,
+                     ck::long_index_t n_out_channels,
+                     ck::long_index_t n_in_channels,
+                     const std::vector<ck::long_index_t>& filters_len,
+                     const std::vector<ck::long_index_t>& input_len,
+                     const std::vector<ck::long_index_t>& strides,
+                     const std::vector<ck::long_index_t>& dilations,
+                     const std::vector<ck::long_index_t>& left_pads,
+                     const std::vector<ck::long_index_t>& right_pads)
     : num_dim_spatial_(n_dim),
       G_(group_count),
       N_(n_batch),
@@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim,
     {
         // XEff = (X - 1) * conv_dilation_w + 1;
         // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
-        const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
+        const ck::long_index_t x_eff =
+            (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
         output_spatial_lengths_[i] =
             (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
@@ -63,7 +121,7 @@ ConvParam::ConvParam()
 {
 }
-std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
+std::vector<ck::long_index_t> ConvParam::GetOutputSpatialLengths() const
 {
     return output_spatial_lengths_;
 }
@@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg()
 ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
 {
-    const ck::index_t G = std::stoi(argv[arg_idx++]);
-    const ck::index_t N = std::stoi(argv[arg_idx++]);
-    const ck::index_t K = std::stoi(argv[arg_idx++]);
-    const ck::index_t C = std::stoi(argv[arg_idx++]);
+    const ck::long_index_t G = std::stol(argv[arg_idx++]);
+    const ck::long_index_t N = std::stol(argv[arg_idx++]);
+    const ck::long_index_t K = std::stol(argv[arg_idx++]);
+    const ck::long_index_t C = std::stol(argv[arg_idx++]);
-    std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
-    std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
-    std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
-    std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
-    std::vector<ck::index_t> input_left_pads(num_dim_spatial);
-    std::vector<ck::index_t> input_right_pads(num_dim_spatial);
+    std::vector<ck::long_index_t> filter_spatial_lengths(num_dim_spatial);
+    std::vector<ck::long_index_t> input_spatial_lengths(num_dim_spatial);
+    std::vector<ck::long_index_t> conv_filter_strides(num_dim_spatial);
+    std::vector<ck::long_index_t> conv_filter_dilations(num_dim_spatial);
+    std::vector<ck::long_index_t> input_left_pads(num_dim_spatial);
+    std::vector<ck::long_index_t> input_right_pads(num_dim_spatial);
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
     }
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
+        input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
     }
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
+        conv_filter_strides[i] = std::stol(argv[arg_idx++]);
     }
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
+        conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
     }
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        input_left_pads[i] = std::stoi(argv[arg_idx++]);
+        input_left_pads[i] = std::stol(argv[arg_idx++]);
     }
     for(int i = 0; i < num_dim_spatial; ++i)
     {
-        input_right_pads[i] = std::stoi(argv[arg_idx++]);
+        input_right_pads[i] = std::stol(argv[arg_idx++]);
     }
     return ck::utils::conv::ConvParam{num_dim_spatial,
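The XEff/Wo arithmetic above is easy to sanity-check in isolation; a minimal standalone sketch (not part of the commit, with values chosen to match the unit tests further down):

#include <cassert>
#include <cstdint>

// Mirrors the commented formula: XEff = (X - 1) * dilation + 1,
// Wo = (Wi + pad_left + pad_right - XEff) / stride + 1.
int64_t conv_out_len(int64_t wi, int64_t x, int64_t stride, int64_t dilation, int64_t pad_l, int64_t pad_r)
{
    const int64_t x_eff = (x - 1) * dilation + 1;
    return (wi + pad_l + pad_r - x_eff) / stride + 1;
}

int main()
{
    assert(conv_out_len(71, 3, /*stride*/ 2, /*dilation*/ 1, 1, 1) == 36); // matches the 1D test, stride 2
    assert(conv_out_len(71, 3, /*stride*/ 3, /*dilation*/ 2, 1, 1) == 23); // stride 3, dilation 2, pad 1
    return 0;
}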
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -82,6 +82,29 @@ bool profile_conv_bwd_data_impl(int do_verification,
     Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
     Tensor<OutDataType> output(out_g_n_k_wos_desc);
+    std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
+    std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
+    std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
+    for(ck::index_t d = 0; d < NDimSpatial; d++)
+    {
+        input_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
+        filter_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
+        output_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
+        conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
+        conv_filter_dilations_i32[d] =
+            static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
+        input_left_pads_i32[d]  = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
+        input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
+    }
     std::cout << "input: " << input_host_result.mDesc << std::endl;
     std::cout << "weight: " << weight.mDesc << std::endl;
     std::cout << "output: " << output.mDesc << std::endl;
@@ -161,16 +184,16 @@ bool profile_conv_bwd_data_impl(int do_verification,
         op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                     static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                     static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-                                    conv_param.N_,
-                                    conv_param.K_,
-                                    conv_param.C_,
-                                    conv_param.input_spatial_lengths_,
-                                    conv_param.filter_spatial_lengths_,
-                                    conv_param.output_spatial_lengths_,
-                                    conv_param.conv_filter_strides_,
-                                    conv_param.conv_filter_dilations_,
-                                    conv_param.input_left_pads_,
-                                    conv_param.input_right_pads_,
+                                    static_cast<ck::index_t>(conv_param.N_),
+                                    static_cast<ck::index_t>(conv_param.K_),
+                                    static_cast<ck::index_t>(conv_param.C_),
+                                    input_spatial_lengths_i32,
+                                    filter_spatial_lengths_i32,
+                                    output_spatial_lengths_i32,
+                                    conv_filter_strides_i32,
+                                    conv_filter_dilations_i32,
+                                    input_left_pads_i32,
+                                    input_right_pads_i32,
                                     in_element_op,
                                     wei_element_op,
                                     out_element_op);
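The `_i32` copies above narrow `ck::long_index_t` back to 32 bits for device ops whose interfaces still take `ck::index_t`; the casts assume the values fit. If that ever needed to be enforced, a checked-narrowing helper along these lines (my sketch, not in the commit) would do:

#include <limits>
#include <stdexcept>

// Hypothetical guard: refuse to silently truncate a 64-bit extent.
inline ck::index_t checked_narrow(ck::long_index_t v)
{
    if(v > static_cast<ck::long_index_t>(std::numeric_limits<ck::index_t>::max()))
        throw std::overflow_error("conv dimension does not fit in ck::index_t");
    return static_cast<ck::index_t>(v);
}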
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -60,6 +60,29 @@ bool profile_conv_fwd_impl(int do_verification,
     Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
     Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
+    std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
+    std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
+    std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
+    std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
+    for(ck::index_t d = 0; d < NDimSpatial; d++)
+    {
+        input_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
+        filter_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
+        output_spatial_lengths_i32[d] =
+            static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
+        conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
+        conv_filter_dilations_i32[d] =
+            static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
+        input_left_pads_i32[d]  = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
+        input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
+    }
     std::cout << "input: " << input.mDesc << std::endl;
     std::cout << "weight: " << weight.mDesc << std::endl;
     std::cout << "output: " << host_output.mDesc << std::endl;
@@ -143,16 +166,16 @@ bool profile_conv_fwd_impl(int do_verification,
         op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
                                     static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
                                     static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
-                                    conv_param.N_,
-                                    conv_param.K_,
-                                    conv_param.C_,
-                                    conv_param.input_spatial_lengths_,
-                                    conv_param.filter_spatial_lengths_,
-                                    conv_param.GetOutputSpatialLengths(),
-                                    conv_param.conv_filter_strides_,
-                                    conv_param.conv_filter_dilations_,
-                                    conv_param.input_left_pads_,
-                                    conv_param.input_right_pads_,
+                                    static_cast<ck::index_t>(conv_param.N_),
+                                    static_cast<ck::index_t>(conv_param.K_),
+                                    static_cast<ck::index_t>(conv_param.C_),
+                                    input_spatial_lengths_i32,
+                                    filter_spatial_lengths_i32,
+                                    output_spatial_lengths_i32,
+                                    conv_filter_strides_i32,
+                                    conv_filter_dilations_i32,
+                                    input_left_pads_i32,
+                                    input_right_pads_i32,
                                     in_element_op,
                                     wei_element_op,
                                     out_element_op);
@@ -33,7 +33,8 @@ template <ck::index_t NDimSpatial,
           typename WeiDataType,
           typename OutDataType,
           typename AComputeType = InDataType,
-          typename BComputeType = AComputeType>
+          typename BComputeType = AComputeType,
+          typename IndexType = ck::index_t>
 bool profile_grouped_conv_fwd_impl(int do_verification,
                                    int init_method,
                                    bool do_log,
@@ -57,16 +58,16 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
     const auto out_g_n_k_wos_desc =
         ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
-    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
-    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
-    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
-    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
-    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
-    std::array<ck::index_t, NDimSpatial> input_left_pads{};
-    std::array<ck::index_t, NDimSpatial> input_right_pads{};
+    std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
+    std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
+    std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
+    std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
+    std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
+    std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
+    std::array<IndexType, NDimSpatial> conv_filter_strides{};
+    std::array<IndexType, NDimSpatial> conv_filter_dilations{};
+    std::array<IndexType, NDimSpatial> input_left_pads{};
+    std::array<IndexType, NDimSpatial> input_right_pads{};
     auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
@@ -46,8 +46,10 @@ if(GPU_TARGETS MATCHES "gfx9")
         list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
     endif()
     list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
-    list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
+    if(GPU_TARGETS MATCHES "gfx94")
+        list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
+        list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
+    endif()
     list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
     list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp)
     list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
@@ -128,8 +130,10 @@ if(GPU_TARGETS MATCHES "gfx9")
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
-    target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
+    if(GPU_TARGETS MATCHES "gfx94")
+        target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
+        target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
+    endif()
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance)
@@ -29,6 +29,12 @@ enum struct ConvDataType
     BF8_F8_F8, // 7
 };
+enum struct IndexType
+{
+    INDEX_T,      // 0
+    LONG_INDEX_T, // 1
+};
 #define OP_NAME "grouped_conv_fwd"
 #define OP_DESC "Grouped Convolution Forward"
@@ -45,12 +51,13 @@ static void print_helper_msg()
         << " 5: Input bf8, Weight bf8, Output fp8\n"
         << " 6: Input fp8, Weight bf8, Output fp8\n"
         << " 7: Input bf8, Weight fp8, Output fp8)\n"
-        << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
+        << "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
+        << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
         << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
-        << "arg4: verification (0: no, 1: yes)\n"
-        << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
-        << "arg6: print tensor value (0: no; 1: yes)\n"
-        << "arg7: time kernel (0: no, 1: yes)\n"
+        << "arg5: verification (0: no, 1: yes)\n"
+        << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
+        << "arg7: print tensor value (0: no; 1: yes)\n"
+        << "arg8: time kernel (0: no, 1: yes)\n"
         << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
     // clang-format on
 }
@@ -60,7 +67,7 @@ static void print_helper_msg()
 int profile_grouped_conv_fwd(int argc, char* argv[])
 {
     // 8 for control, 1 for num_dim_spatial
-    if(argc < 9)
+    if(argc < 10)
     {
         print_helper_msg();
         return 1;
@@ -68,20 +75,21 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
     const auto layout    = static_cast<ConvLayout>(std::stoi(argv[3]));
-    const bool do_verification = std::stoi(argv[4]);
-    const int init_method      = std::stoi(argv[5]);
-    const bool do_log          = std::stoi(argv[6]);
-    const bool time_kernel     = std::stoi(argv[7]);
-    const int num_dim_spatial  = std::stoi(argv[8]);
+    const auto index_type      = static_cast<IndexType>(std::stoi(argv[4]));
+    const bool do_verification = std::stoi(argv[5]);
+    const int init_method      = std::stoi(argv[6]);
+    const bool do_log          = std::stoi(argv[7]);
+    const bool time_kernel     = std::stoi(argv[8]);
+    const int num_dim_spatial  = std::stoi(argv[9]);
-    // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
-    if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
+    // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
+    if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
     {
         print_helper_msg();
         return 1;
     }
-    const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
+    const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
     using F32 = float;
     using F16 = ck::half_t;
@@ -138,18 +146,43 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
         using AComputeType = decltype(a_compute_type);
         using BComputeType = decltype(b_compute_type);
-        bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
-                                                                InLayout,
-                                                                WeiLayout,
-                                                                OutLayout,
-                                                                InDataType,
-                                                                WeiDataType,
-                                                                OutDataType,
-                                                                AComputeType,
-                                                                BComputeType>(
-            do_verification, init_method, do_log, time_kernel, params);
-        return pass ? 0 : 1;
+        if(index_type == IndexType::INDEX_T)
+        {
+            bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
+                                                                    InLayout,
+                                                                    WeiLayout,
+                                                                    OutLayout,
+                                                                    InDataType,
+                                                                    WeiDataType,
+                                                                    OutDataType,
+                                                                    AComputeType,
+                                                                    BComputeType,
+                                                                    ck::index_t>(
+                do_verification, init_method, do_log, time_kernel, params);
+            return pass ? 0 : 1;
+        }
+        else if(index_type == IndexType::LONG_INDEX_T)
+        {
+            bool pass = ck::profiler::profile_grouped_conv_fwd_impl<NDimSpatial,
+                                                                    InLayout,
+                                                                    WeiLayout,
+                                                                    OutLayout,
+                                                                    InDataType,
+                                                                    WeiDataType,
+                                                                    OutDataType,
+                                                                    AComputeType,
+                                                                    BComputeType,
+                                                                    ck::long_index_t>(
+                do_verification, init_method, do_log, time_kernel, params);
+            return pass ? 0 : 1;
+        }
+        else
+        {
+            std::cout << "this indexing data type is not implemented" << std::endl;
+            return 1;
+        }
     };
 // GNHWC_GKYXC_GNHWK
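With the new arg3, an invocation now carries one extra control argument; a hypothetical 2D fp16 NHWGC run with 64-bit indexing might look like this (binary name and sizes illustrative, argument order taken from the helper text above):

# dtype=1 (fp16)  layout=1 (NHWGC)  index=1 (64-bit)  verify=1  init=1  log=0  time=1
# then ndims=2, G N K C, and 6*ndims spatial values: Y X, Hi Wi, Sy Sx, Dy Dx, LeftPy LeftPx, RightPy RightPx
./ckProfiler grouped_conv_fwd 1 1 1 1 1 0 1 2  32 256 192 192  3 3  71 71  2 2  1 1  1 1  1 1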
@@ -62,17 +62,13 @@ def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
                 i_current = i_next + 1
             if i_next == -1:
                 break
+        # pad with `None`s for the fields which are not defined in the instance
+        template_args.insert(2, tuple())  # ds layout
+        template_args.insert(6, tuple())  # ds dtype
         new_instance = CKGemmOperation(
             *template_args,  # type: ignore[arg-type]
-            *((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
         )
-        # the last 2 template parameters are optional
-        # if they are absent, substitute them with default values from Universal Gemm C++ template declaration
-        if new_instance.a_compute_dtype is None:
-            new_instance.a_compute_dtype = new_instance.c_element_dtype
-        if new_instance.b_compute_dtype is None:
-            new_instance.b_compute_dtype = new_instance.c_element_dtype
         op_instances.append(new_instance)
     return op_instances
@@ -208,6 +204,8 @@ def gen_ops_preselected() -> List[CKGemmOperation]:
         a_layout="Row",
         b_layout="Col",
         c_layout="Row",
+        ds_element_dtypes=tuple(),
+        ds_layouts=tuple(),
         a_element_dtype="F16",
         b_element_dtype="F16",
         c_element_dtype="F16",
@@ -10,10 +10,12 @@ class CKGemmOperation:
     a_layout: str
     b_layout: str
+    ds_layouts: Tuple[str]  # addmm specific
     c_layout: str
     a_element_dtype: str
     b_element_dtype: str
+    ds_element_dtypes: Tuple[str]  # addmm specific
     c_element_dtype: str
     acc_dtype: str
@@ -64,16 +66,15 @@ class CKGemmOperation:
         Tuple[int, int, int, int]
     )
     c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
     block_gemm_pipeline_scheduler: str
-    block_gemm_pipeline_version: Optional[str]
-    a_compute_dtype: Optional[str]
-    b_compute_dtype: Optional[str]
+    block_gemm_pipeline_version: str
+    a_compute_dtype: Optional[str] = None
+    b_compute_dtype: Optional[str] = None
     def name(self):
         # cpp alias for template instance
-        return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"
+        return f"ck_devicegemm_multid_xdl_shuffle_v3_{self.key_name()}"
     def key_name(self):
         # TBD; must be unique per instance. Intended to use as dict key
@@ -143,6 +143,12 @@ def parse_logfile(logfile):
            if 'Best Perf' in line:
                lst=line.split()
                res.append(lst[36])
+    elif 'perf_fmha' in logfile:
+        for line in open(logfile):
+            if 'TFlops' in line:
+                lst=line.split()
+                line_dict=dict(zip(lst[1:],lst))
+                res.append(line_dict['TFlops,'])
     return res
@@ -304,6 +310,14 @@ def main():
         for i in range(1,len(results)+1):
             testlist.append("Test%i"%i)
         table_name="ck_mixed_gemm_tflops"
+    if 'fmha_fwd' in filename:
+        for i in range(1,len(results)+1):
+            testlist.append("Test%i"%i)
+        table_name="ck_fmha_fwd_tflops"
+    if 'fmha_bwd' in filename:
+        for i in range(1,len(results)+1):
+            testlist.append("Test%i"%i)
+        table_name="ck_fmha_bwd_tflops"
     tflops_base = get_baseline(table_name,conn)
     store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn)
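The `dict(zip(lst[1:], lst))` idiom above maps every whitespace-split token to the token that precedes it, so `line_dict['TFlops,']` picks out the number printed just before the literal `TFlops,`. A tiny illustration (the log line is invented):

lst = ['[perf]', '123.45', 'TFlops,', '0.87', 'ms']
line_dict = dict(zip(lst[1:], lst))  # token -> preceding token
assert line_dict['TFlops,'] == '123.45'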
@@ -13,3 +13,20 @@
 python3 process_perf_data.py perf_gemm.log
 python3 process_perf_data.py perf_resnet50_N256.log
 python3 process_perf_data.py perf_resnet50_N4.log
+file=./perf_fmha_fwd_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_fwd_gfx942.log
+fi
+file=./perf_fmha_bwd_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_bwd_gfx942.log
+fi
+file=./perf_fmha_fwd_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_fwd_gfx90a.log
+fi
+file=./perf_fmha_bwd_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
+fi
@@ -21,3 +21,20 @@ python3 process_perf_data.py perf_gemm_bilinear.log
 python3 process_perf_data.py perf_reduction.log
 python3 process_perf_data.py perf_splitK_gemm.log
 python3 process_perf_data.py perf_onnx_gemm.log
+file=./perf_fmha_fwd_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_fwd_gfx942.log
+fi
+file=./perf_fmha_bwd_gfx942.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_bwd_gfx942.log
+fi
+file=./perf_fmha_fwd_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_fwd_gfx90a.log
+fi
+file=./perf_fmha_bwd_gfx90a.log
+if [ -e "$file" ]; then
+    python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
+fi
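Since the same existence-check-and-process block now appears four times in each script, it could equally be written as a loop; an equivalent sketch (not what the commit does):

for f in perf_fmha_fwd_gfx942 perf_fmha_bwd_gfx942 perf_fmha_fwd_gfx90a perf_fmha_bwd_gfx90a; do
    [ -e "./$f.log" ] && python3 process_perf_data.py "$f.log"
done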
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #include <iostream>
 #include <string>
@@ -24,12 +24,12 @@ class TestConvUtil : public ::testing::Test
                        128,
                        192,
                        256,
-                       std::vector<ck::index_t>(ndims, 3),
-                       std::vector<ck::index_t>(ndims, 71),
-                       std::vector<ck::index_t>(ndims, s),
-                       std::vector<ck::index_t>(ndims, d),
-                       std::vector<ck::index_t>(ndims, p),
-                       std::vector<ck::index_t>(ndims, p));
+                       std::vector<ck::long_index_t>(ndims, 3),
+                       std::vector<ck::long_index_t>(ndims, 71),
+                       std::vector<ck::long_index_t>(ndims, s),
+                       std::vector<ck::long_index_t>(ndims, d),
+                       std::vector<ck::long_index_t>(ndims, p),
+                       std::vector<ck::long_index_t>(ndims, p));
     }
     protected:
@@ -48,35 +48,35 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
 {
     // stride 2, dilation 1, pad 1
     SetNDParams(1, 2, 1, 1);
-    std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
+    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D."));
+        out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D."));
     // stride 1, dilation 1, pad 1
     SetNDParams(1, 1, 1, 1);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}."));
+        out_spatial_len, std::vector<ck::long_index_t>{71}, "Error: ConvParams 1D stride {1}."));
     // stride 2, dilation 1, pad 2
     SetNDParams(1, 2, 1, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{37},
+                                     std::vector<ck::long_index_t>{37},
                                      "Error: ConvParams 1D padding left/right {2}."));
     // stride 2, dilation 2, pad 2
     SetNDParams(1, 2, 2, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}."));
+        out_spatial_len, std::vector<ck::long_index_t>{36}, "Error: ConvParams 1D dilation {2}."));
     // stride 3, dilation 2, pad 1
     SetNDParams(1, 3, 2, 1);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(
         ck::utils::check_err(out_spatial_len,
-                             std::vector<ck::index_t>{23},
+                             std::vector<ck::long_index_t>{23},
                              "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
 }
@@ -84,36 +84,38 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
 {
     // stride 2, dilation 1, pad 1
     SetNDParams(2, 2, 1, 1);
-    std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
+    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{36, 36},
+                                     std::vector<ck::long_index_t>{36, 36},
                                      "Error: ConvParams 2D default constructor."));
     // stride 1, dilation 1, pad 1
     SetNDParams(2, 1, 1, 1);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
-    EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));
+    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
+                                     std::vector<ck::long_index_t>{71, 71},
+                                     "Error: ConvParams 2D stride {1,1}."));
     // stride 2, dilation 1, pad 2
     SetNDParams(2, 2, 1, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{37, 37},
+                                     std::vector<ck::long_index_t>{37, 37},
                                      "Error: ConvParams 2D padding left/right {2,2}."));
     // stride 2, dilation 2, pad 2
     SetNDParams(2, 2, 2, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
-    EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));
+    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
+                                     std::vector<ck::long_index_t>{36, 36},
+                                     "Error: ConvParams 2D dilation {2,2}."));
     // stride 3, dilation 2, pad 1
     SetNDParams(2, 3, 2, 1);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(
         ck::utils::check_err(out_spatial_len,
-                             std::vector<ck::index_t>{23, 23},
+                             std::vector<ck::long_index_t>{23, 23},
                              "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
 }
@@ -121,29 +123,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
 {
     // stride 2, dilation 1, pad 1
     SetNDParams(3, 2, 1, 1);
-    std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
+    std::vector<ck::long_index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(
-        out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D."));
+        out_spatial_len, std::vector<ck::long_index_t>{36, 36, 36}, "Error: ConvParams 3D."));
     // stride 1, dilation 1, pad 1
     SetNDParams(3, 1, 1, 1);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{71, 71, 71},
+                                     std::vector<ck::long_index_t>{71, 71, 71},
                                      "Error: ConvParams 3D stride {1, 1, 1}."));
     // stride 2, dilation 1, pad 2
     SetNDParams(3, 2, 1, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{37, 37, 37},
+                                     std::vector<ck::long_index_t>{37, 37, 37},
                                      "Error: ConvParams 3D padding left/right {2, 2, 2}."));
     // stride 2, dilation 2, pad 2
     SetNDParams(3, 2, 2, 2);
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
-                                     std::vector<ck::index_t>{36, 36, 36},
+                                     std::vector<ck::long_index_t>{36, 36, 36},
                                      "Error: ConvParams 3D dilation {2, 2, 2}."));
     // stride 3, dilation 2, pad 1
@@ -151,6 +153,6 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
     out_spatial_len = conv_params.GetOutputSpatialLengths();
     EXPECT_TRUE(ck::utils::check_err(
         out_spatial_len,
-        std::vector<ck::index_t>{23, 23, 23},
+        std::vector<ck::long_index_t>{23, 23, 23},
         "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
 }
@@ -17,6 +17,7 @@ class TestGroupedConvndFwd : public ::testing::Test
     using InLayout  = std::tuple_element_t<1, Tuple>;
     using WeiLayout = std::tuple_element_t<2, Tuple>;
     using OutLayout = std::tuple_element_t<3, Tuple>;
+    using IndexType = std::tuple_element_t<4, Tuple>;
     std::vector<ck::utils::conv::ConvParam> conv_params;
@@ -33,7 +34,10 @@ class TestGroupedConvndFwd : public ::testing::Test
                                                              OutLayout,
                                                              DataType,
                                                              DataType,
-                                                             DataType>(
+                                                             DataType,
+                                                             DataType,
+                                                             DataType,
+                                                             IndexType>(
             true,  // do_verification
             1,     // init_method: integer value
             false, // do_log
@@ -46,30 +50,31 @@ class TestGroupedConvndFwd : public ::testing::Test
 using namespace ck::tensor_layout::convolution;
-using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK>,
-                                       std::tuple<ck::half_t, GNWC, GKXC, GNWK>,
-                                       std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK>,
-                                       std::tuple<int8_t, GNWC, GKXC, GNWK>>;
+using KernelTypes1d = ::testing::Types<std::tuple<float, GNWC, GKXC, GNWK, ck::index_t>,
+                                       std::tuple<ck::half_t, GNWC, GKXC, GNWK, ck::index_t>,
+                                       std::tuple<ck::bhalf_t, GNWC, GKXC, GNWK, ck::index_t>,
+                                       std::tuple<int8_t, GNWC, GKXC, GNWK, ck::index_t>>;
-using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
-                                       std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK>,
-                                       std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK>,
-                                       std::tuple<int8_t, GNHWC, GKYXC, GNHWK>,
-                                       std::tuple<float, NHWGC, GKYXC, NHWGK>,
-                                       std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
-                                       std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
-                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
+using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK, ck::index_t>,
+                                       std::tuple<ck::half_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
+                                       std::tuple<ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
+                                       std::tuple<int8_t, GNHWC, GKYXC, GNHWK, ck::index_t>,
+                                       std::tuple<float, NHWGC, GKYXC, NHWGK, ck::index_t>,
+                                       std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
+                                       std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::index_t>,
+                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK, ck::index_t>>;
-using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
-                                       std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
-                                       std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK>,
-                                       std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK>,
-                                       std::tuple<float, NDHWGC, GKZYXC, NDHWGK>,
-                                       std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK>,
-                                       std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK>,
-                                       std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK>>;
+using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
+                                       std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
+                                       std::tuple<ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
+                                       std::tuple<int8_t, GNDHWC, GKZYXC, GNDHWK, ck::index_t>,
+                                       std::tuple<float, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
+                                       std::tuple<ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
+                                       std::tuple<ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>,
+                                       std::tuple<int8_t, NDHWGC, GKZYXC, NDHWGK, ck::index_t>>;
-using KernelTypes2dLargeCases = ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK>>;
+using KernelTypes2dLargeCases =
+    ::testing::Types<std::tuple<float, NHWGC, GKYXC, NHWGK, ck::long_index_t>>;
 template <typename Tuple>
 class TestGroupedConvndFwd1d : public TestGroupedConvndFwd<Tuple>
@@ -153,5 +158,8 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases)
     // With supported NumGroupsToMerge > 1
     this->conv_params.push_back(
         {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}});
+    // When image is larger than 2GB
+    this->conv_params.push_back(
+        {2, 1, 1, 256, 256, {3, 3}, {4096, 2048}, {1024, 1024}, {3, 3}, {1, 1}, {1, 1}});
     this->template Run<2>();
 }
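A quick check of why that added case counts as a "large tensor" (my arithmetic, not from the commit): with G=1, N=1, C=256 and a 4096x2048 input, the input alone holds 256 * 4096 * 2048 = 2^31 elements, which already exceeds the 2^31 - 1 maximum of a 32-bit ck::index_t, so packed offsets only fit in the 64-bit ck::long_index_t path exercised here:

#include <cstdint>
#include <limits>
static_assert(1LL * 256 * 4096 * 2048 > std::numeric_limits<std::int32_t>::max(),
              "input element count 2^31 overflows 32-bit indexing");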