Commit df35f46d authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 9c0811f3 7733ae16
...@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType, YDataType,
MeanDataType, MeanDataType,
InvStdDataType, InvStdDataType,
Shape>; Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>; using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
......
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)
# Image to Column
This folder contains an example of Image to Column (im2col) using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your GPU target, e.g. gfx90a, gfx942...
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
/// Host-side launcher for the image-to-column kernel, specialized for two spatial dimensions.
///
/// @param traits       Selects the kernel configuration; only `data_type == "fp16"` is handled.
/// @param args         Buffer pointers plus the sizes/strides forwarded to `Kernel::MakeKargs`.
/// @param stream_conf  Stream configuration handed to `ck_tile::launch_kernel`.
/// @return The value returned by `ck_tile::launch_kernel`, or 0 when `traits.data_type`
///         is anything other than "fp16" (no kernel is launched in that case).
template <>
float image_to_column(const image_to_column_traits& traits,
                      const image_to_column_args<2>& args,
                      const ck_tile::stream_config& stream_conf)
{
    // Guard clause: nothing but the fp16 configuration is instantiated here.
    if(traits.data_type != "fp16")
    {
        return 0;
    }

    constexpr ck_tile::index_t NDimSpatial = 2;
    constexpr ck_tile::index_t VectorSize  = 8;

    // Tile hierarchy used to build the kernel shape.
    using ThreadTile = ck_tile::sequence<8, 8>;
    using WarpTile   = ck_tile::sequence<64, 64>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using Shape      = ck_tile::TileImageToColumnShape<ThreadTile, WarpTile, BlockTile>;

    using InDataType  = ck_tile::half_t;
    using OutDataType = ck_tile::half_t;

    // VectorSize is supplied twice, once for each of the problem's last two parameters.
    using PipelineProblem = ck_tile::BlockImageToColumnProblem<InDataType,
                                                               OutDataType,
                                                               Shape,
                                                               NDimSpatial,
                                                               VectorSize,
                                                               VectorSize>;
    using Kernel = ck_tile::ImageToColumn<PipelineProblem>;

    auto kargs = Kernel::MakeKargs(args.p_in,
                                   args.p_out,
                                   args.G,
                                   args.N,
                                   args.C,
                                   args.input_spatial_lengths,
                                   args.filter_spatial_lengths,
                                   args.output_spatial_lengths,
                                   args.image_g_n_c_wis_strides,
                                   args.gemm_g_m_k_strides,
                                   args.conv_filter_strides,
                                   args.conv_filter_dilations,
                                   args.input_left_pads,
                                   args.input_right_pads);

    // Extents of the output matrix: rows = N * Ho * Wo, columns = Y * X * C.
    const auto gemm_rows =
        args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1];
    const auto gemm_cols =
        args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C;

    const dim3 grid_dim       = Kernel::GridSize(gemm_rows, gemm_cols, args.G);
    constexpr dim3 block_dim  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 2;

    return ck_tile::launch_kernel(
        stream_conf,
        ck_tile::make_kernel<block_dim.x, kBlockPerCu>(Kernel{}, grid_dim, block_dim, 0, kargs));
}
int main(int argc, char* argv[])
{
constexpr ck_tile::index_t NDimSpatial = 2;
ExecutionConfig config;
ck_tile::conv::ConvParam conv_params = DefaultConvParams;
if(!parse_cmd_args(argc, argv, config, conv_params))
{
return EXIT_FAILURE;
}
if(conv_params.num_dim_spatial_ != NDimSpatial)
{
std::cerr << "unsupported # of spatial dimensions" << std::endl;
return EXIT_FAILURE;
}
using InDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
const auto G = conv_params.G_;
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck_tile::long_index_t NHoWo =
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const ck_tile::long_index_t CYX =
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX});
// host verify
ck_tile::HostTensor<InDataType> in(in_desc);
ck_tile::HostTensor<OutDataType> out_device(out_desc);
ck_tile::HostTensor<OutDataType> out_host(out_desc);
switch(config.init_method)
{
case 0: break;
case 1: ck_tile::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in); break;
default: ck_tile::FillUniformDistribution<InDataType>{-0.5, 0.5}(in); break;
}
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
in_device_buf.ToDevice(in.data());
image_to_column_traits traits{"fp16"};
image_to_column_args<NDimSpatial> args{
in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
G,
N,
C,
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.filter_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.output_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_dilations_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_)};
float ave_time =
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel});
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(config.do_verification)
{
// reference
ck_tile::reference_im2col<InDataType, OutDataType, NDimSpatial>(in, out_host, conv_params);
out_device_buf.FromDevice(out_device.data());
pass = ck_tile::check_err(out_device, out_host);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return !pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <exception>
#include <iostream>
#include <string>
// Default convolution problem used when no convolution arguments are given on the
// command line. Expands to a brace-initialized ck_tile::conv::ConvParam temporary.
// The five leading scalars are presumably num_dim_spatial_, G_, N_, K_, C_ (the order
// named in parse_cmd_args); the six two-element lists that follow are the per-spatial-
// dimension parameters. NOTE(review): confirm the list order against ConvParam's
// constructor before changing any value.
#define DefaultConvParams \
ck_tile::conv::ConvParam \
{ \
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
// Run-time switches of the example, filled in from the first three command-line arguments.
struct ExecutionConfig final
{
    // arg1: compare the device result against the host reference.
    bool do_verification{true};
    // arg2: input initialization (0 = none, 1 = integer values, otherwise decimal values).
    int init_method{1};
    // arg3: forwarded to the stream configuration to request kernel timing.
    bool time_kernel{false};
};
// Writes the usage text to stderr: the three execution-config arguments followed by the
// convolution-parameter help supplied by ck_tile. The stream is flushed at the end.
inline void print_help_msg()
{
    std::cerr << "arg1: verification (0=no, 1=yes)\n";
    std::cerr << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n";
    std::cerr << "arg3: time kernel (0=no, 1=yes)\n";
    std::cerr << ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck_tile::conv::ConvParam& conv_params)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
config = ExecutionConfig{};
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params =
ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
// Compile-time-configuration selector passed to the host API. The 2D specialization of
// image_to_column() only recognizes the value "fp16".
struct image_to_column_traits
{
// Name of the element data type, compared as a string by the host API.
std::string data_type;
};
// Run-time arguments of the image-to-column host API for NDimSpatial spatial dimensions.
// Every member is forwarded unchanged to Kernel::MakeKargs by the host implementation.
template <ck_tile::index_t NDimSpatial>
struct image_to_column_args
{
// Input image buffer (the example passes a device buffer pointer).
const void* p_in;
// Output buffer (the example passes a device buffer pointer).
void* p_out;
// Leading problem sizes; presumably groups, batch and channels, matching
// ConvParam's G_, N_ and C_ from which the example fills them.
const ck_tile::long_index_t G;
const ck_tile::long_index_t N;
const ck_tile::long_index_t C;
// Per-spatial-dimension extents of the input image, the filter window and the output.
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths;
// Strides of the input image tensor (NDimSpatial + 3 entries; the example fills them
// from the input host tensor descriptor).
const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
// Strides of the 3-dimensional output tensor (the example fills them from the
// {G, N*Ho*Wo, C*Y*X} output descriptor).
const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides;
// Per-spatial-dimension convolution parameters.
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads;
};
// host API
// Declaration only: the definition is provided per NDimSpatial as an explicit
// specialization (image_to_column.cpp supplies the one for NDimSpatial == 2).
// Takes the configuration selector, the run-time arguments and the stream configuration.
// The 2D specialization returns the value produced by ck_tile::launch_kernel, or 0 when
// the requested data type is not supported.
template <ck_tile::index_t NDimSpatial>
float image_to_column(const image_to_column_traits&,
const image_to_column_args<NDimSpatial>&,
const ck_tile::stream_config&);
...@@ -5,3 +5,4 @@ include_directories(AFTER ...@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory(01_fmha) add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d) add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm) add_subdirectory(03_gemm)
add_subdirectory(04_img2col)
...@@ -97,13 +97,6 @@ ...@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<1>() __device__ constexpr auto TailScheduler<1>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<2>() __device__ constexpr auto TailScheduler<2>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
......
...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -324,55 +324,55 @@ struct DppSelector ...@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp(); static constexpr auto GetDpp();
template <> template <>
static constexpr auto GetDpp<half_t, 8, 32>() constexpr auto GetDpp<half_t, 8, 32>()
{ {
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 8, 16>() constexpr auto GetDpp<half_t, 8, 16>()
{ {
return DppInstr::dpp8_f16_8x16x2; return DppInstr::dpp8_f16_8x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() constexpr auto GetDpp<half_t, 16, 16>()
{ {
return DppInstr::dpp8_f16_16x16x2; return DppInstr::dpp8_f16_16x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 32, 8>() constexpr auto GetDpp<half_t, 32, 8>()
{ {
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 1, 32>() constexpr auto GetDpp<half_t, 1, 32>()
{ {
return DppInstr::dpp8_f16_1x32x2; return DppInstr::dpp8_f16_1x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 32>() constexpr auto GetDpp<half_t, 2, 32>()
{ {
return DppInstr::dpp8_f16_2x32x2; return DppInstr::dpp8_f16_2x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 16>() constexpr auto GetDpp<half_t, 2, 16>()
{ {
return DppInstr::dpp8_f16_2x16x2; return DppInstr::dpp8_f16_2x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 16>() constexpr auto GetDpp<half_t, 4, 16>()
{ {
return DppInstr::dpp8_f16_4x16x2; return DppInstr::dpp8_f16_4x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 32>() constexpr auto GetDpp<half_t, 4, 32>()
{ {
return DppInstr::dpp8_f16_4x32x2; return DppInstr::dpp8_f16_4x32x2;
} }
......
...@@ -415,7 +415,7 @@ struct WmmaSelector ...@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma(); static constexpr auto GetWmma();
template <> template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>() constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
...@@ -425,7 +425,7 @@ struct WmmaSelector ...@@ -425,7 +425,7 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
...@@ -435,19 +435,19 @@ struct WmmaSelector ...@@ -435,19 +435,19 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>() constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{ {
return WmmaInstr::wmma_f16_16x16x16_f16; return WmmaInstr::wmma_f16_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{ {
return WmmaInstr::wmma_bf16_16x16x16_bf16; return WmmaInstr::wmma_bf16_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>() constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
...@@ -458,7 +458,7 @@ struct WmmaSelector ...@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>() constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu4; return WmmaInstr::wmma_i32_16x16x16_iu4;
} }
......
...@@ -651,97 +651,97 @@ struct MfmaSelector ...@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
static constexpr auto GetMfma<double, 16, 16>() constexpr auto GetMfma<double, 16, 16>()
{ {
return MfmaInstr::mfma_f64_16x16x4f64; return MfmaInstr::mfma_f64_16x16x4f64;
} }
template <> template <>
static constexpr auto GetMfma<float, 64, 64>() constexpr auto GetMfma<float, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 64>() constexpr auto GetMfma<float, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 64>() constexpr auto GetMfma<float, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x1xf32; return MfmaInstr::mfma_f32_16x16x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 8, 64>() constexpr auto GetMfma<float, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 4, 64>() constexpr auto GetMfma<float, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 32>() constexpr auto GetMfma<float, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x2xf32; return MfmaInstr::mfma_f32_32x32x2xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 16>() constexpr auto GetMfma<float, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x4xf32; return MfmaInstr::mfma_f32_16x16x4xf32;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 64, 64>() constexpr auto GetMfma<half_t, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 64>() constexpr auto GetMfma<half_t, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 32>() constexpr auto GetMfma<half_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 16>() constexpr auto GetMfma<half_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x16f16; return MfmaInstr::mfma_f32_16x16x16f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 64>() constexpr auto GetMfma<half_t, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x4f16; return MfmaInstr::mfma_f32_16x16x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 8, 64>() constexpr auto GetMfma<half_t, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 4, 64>() constexpr auto GetMfma<half_t, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>() constexpr auto GetMfma<bhalf_t, 32, 32>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k; return MfmaInstr::mfma_f32_32x32x8bf16_1k;
...@@ -751,7 +751,7 @@ struct MfmaSelector ...@@ -751,7 +751,7 @@ struct MfmaSelector
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>() constexpr auto GetMfma<bhalf_t, 16, 16>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k; return MfmaInstr::mfma_f32_16x16x16bf16_1k;
...@@ -762,72 +762,72 @@ struct MfmaSelector ...@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940) #if defined(CK_USE_AMD_MFMA_GFX940)
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x16i8; return MfmaInstr::mfma_i32_32x32x16i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x32i8; return MfmaInstr::mfma_i32_16x16x32i8;
} }
#else #else
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x16i8; return MfmaInstr::mfma_i32_16x16x16i8;
} }
#endif #endif
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() constexpr auto GetMfma<f8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8f8; return MfmaInstr::mfma_f32_32x32x16f8f8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16>() constexpr auto GetMfma<f8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8bf8; return MfmaInstr::mfma_f32_32x32x16bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16>() constexpr auto GetMfma<bf8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8bf8; return MfmaInstr::mfma_f32_32x32x16f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>() constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8f8; return MfmaInstr::mfma_f32_32x32x16bf8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>() constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
...@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr ...@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return !(a == b); return !(a == b);
} }
// Builds an array<T, N> from the first N elements of a std::vector<X>.
// Each element is assigned by index (x[i] into arr(i)), relying on X being assignable to T.
// No size check is performed here: the caller must guarantee x.size() >= N.
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
{
array<T, N> arr;
// Compile-time unrolled copy over the indices 0 .. N-1.
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
template <typename T, index_t N, typename X> template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x) CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{ {
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp" #include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp" #include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp" #include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment