Commit df35f46d authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 9c0811f3 7733ae16
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
namespace conv {
namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
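// Illustrative note (not part of the original source): the vector returned above is a
// new2old permutation, i.e. entry i names the physical dimension that becomes logical
// dimension i of the GNCHW-ordered descriptor. For example, NHWGC stores dimensions as
// (N, Hi, Wi, G, C), so {3, 0, 4, 1, 2} reorders them to (G, N, C, Hi, Wi).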
} // namespace detail
// make a tensor descriptor for the packed input tensor, with dimensions ordered as GNCHW
// regardless of the physical layout
template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
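// Illustrative usage sketch (hedged, not part of this header); the weight and output
// helpers below follow the same pattern. Assuming a 2D NHWGC problem with G=1, N=4,
// K=64, C=32, a 3x3 filter, a 28x28 input, unit strides/dilations and 1-pixel padding:
//
//   ck_tile::conv::ConvParam param(2, 1, 4, 64, 32,
//                                  {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
//   auto in_desc = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
//       ck_tile::tensor_layout::convolution::NHWGC>(param);
//
// in_desc.get_lengths() is then ordered (G, N, C, Hi, Wi) = (1, 4, 32, 28, 28), while its
// strides still describe the packed NHWGC memory layout.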
// make a tensor descriptor for the packed weight tensor, with dimensions ordered as GKCYX
// regardless of the physical layout
template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make a tensor descriptor for the packed output tensor, with dimensions ordered as GNKHW
// regardless of the physical layout
template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
} // namespace conv
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace ck_tile {
namespace conv {
struct ConvParam
{
ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count,
ck_tile::index_t n_batch,
ck_tile::index_t n_out_channels,
ck_tile::index_t n_in_channels,
const std::vector<ck_tile::index_t>& filters_len,
const std::vector<ck_tile::index_t>& input_len,
const std::vector<ck_tile::index_t>& strides,
const std::vector<ck_tile::index_t>& dilations,
const std::vector<ck_tile::index_t>& left_pads,
const std::vector<ck_tile::index_t>& right_pads)
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
G_(static_cast<ck_tile::long_index_t>(group_count)),
N_(static_cast<ck_tile::long_index_t>(n_batch)),
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
filter_spatial_lengths_(num_dim_spatial_),
input_spatial_lengths_(num_dim_spatial_),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(num_dim_spatial_),
conv_filter_dilations_(num_dim_spatial_),
input_left_pads_(num_dim_spatial_),
input_right_pads_(num_dim_spatial_)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
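// e.g. Wi = 28, X = 3, dilation = 2, pads = 1/1, stride = 2:
// XEff = (3 - 1) * 2 + 1 = 5 and Wo = (28 + 1 + 1 - 5) / 2 + 1 = 13 (integer division)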
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ConvParam(ck_tile::long_index_t n_dim,
ck_tile::long_index_t group_count,
ck_tile::long_index_t n_batch,
ck_tile::long_index_t n_out_channels,
ck_tile::long_index_t n_in_channels,
const std::vector<ck_tile::long_index_t>& filters_len,
const std::vector<ck_tile::long_index_t>& input_len,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads,
const std::vector<ck_tile::long_index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
K_(n_out_channels),
C_(n_in_channels),
filter_spatial_lengths_(filters_len),
input_spatial_lengths_(input_len),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(strides),
conv_filter_dilations_(dilations),
input_left_pads_(left_pads),
input_right_pads_(right_pads)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ck_tile::long_index_t num_dim_spatial_;
ck_tile::long_index_t G_;
ck_tile::long_index_t N_;
ck_tile::long_index_t K_;
ck_tile::long_index_t C_;
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
std::vector<ck_tile::long_index_t> conv_filter_strides_;
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
std::vector<ck_tile::long_index_t> input_left_pads_;
std::vector<ck_tile::long_index_t> input_right_pads_;
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
{
return output_spatial_lengths_;
}
std::size_t GetFlops() const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>());
}
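// Illustrative example (hedged): G = 1, N = 4, K = 64, C = 32 with a 28x28 output and a
// 3x3 filter gives 2 * 1 * 4 * 64 * 32 * 784 * 9 = 115,605,504 flops.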
template <typename InDataType>
std::size_t GetInputByte() const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>)
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename WeiDataType>
std::size_t GetWeightByte() const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>)
return sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename OutDataType>
std::size_t GetOutputByte() const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
GetOutputByte<OutDataType>();
}
};
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
std::string msg;
msg += "Following arguments (depending on number of spatial dims):\n"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
" G, N, K, C, \n"
" <filter spatial dimensions>, (ie Y, X for 2D)\n"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
" <strides>, (ie Sy, Sx for 2D)\n"
" <dilations>, (ie Dy, Dx for 2D)\n"
" <left padding>, (ie LeftPy, LeftPx for 2D)\n"
" <right padding>, (ie RightPy, RightPx for 2D)\n";
return msg;
}
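// Illustrative example (hedged): following the message above, a 2D problem is described
// on the command line as
//   2  G N K C  Y X  Hi Wi  Sy Sx  Dy Dx  LeftPy LeftPx  RightPy RightPx
// e.g. "2 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1". Note that parse_conv_param() below
// starts reading at G; the leading spatial-dimension count is consumed by the caller.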
CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stol(argv[arg_idx++]);
}
return ck_tile::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
} // namespace conv
} // namespace ck_tile
......@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
{
os << "dim " << desc.get_num_of_dimension() << ", ";
os << "lengths {";
LogRange(os, desc.get_lengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.get_strides(), ", ");
os << "}";
return os;
}
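// Illustrative output (hedged): a packed descriptor with lengths {2, 4, 8} would print as
//   dim 3, lengths {2, 4, 8}, strides {32, 8, 1}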
private:
std::vector<std::size_t> mLens;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,53 +9,125 @@
namespace ck_tile {
template <typename T>
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
HostTensor<OutDataType>& out_host,
const ck_tile::conv::ConvParam& conv_params)
{
int GemmM = in_mtx_host_ref.get_lengths()[0];
int GemmK = in_mtx_host_ref.get_lengths()[1];
const long_index_t G = in_host.get_lengths()[0];
const long_index_t N = in_host.get_lengths()[1];
const long_index_t C = in_host.get_lengths()[2];
for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m)
if constexpr(NDimSpatial == 1)
{
int mtmp = gemm_m;
int n = mtmp / (Ho * Wo);
mtmp -= n * Ho * Wo;
int ho = mtmp / Wo;
int wo = mtmp - ho * Wo;
for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k)
{
int ktmp = gemm_k;
int y = ktmp / (X * C);
ktmp -= y * X * C;
int x = ktmp / C;
int c = ktmp - x * C;
int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH;
int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW;
bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi);
in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0;
}
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
{
InDataType v_in = in_host(g, n, c, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
{
InDataType v_in = in_host(g, n, c, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t Do = conv_params.output_spatial_lengths_[0];
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
hi >= 0 &&
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
{
InDataType v_in = in_host(g, n, c, di, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
}
}
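// Illustrative note (not part of the original source): for NDimSpatial == 2, out_host is
// the im2col ("GEMM") view of the input with
//   rows    M = N * Ho * Wo   (one row per output pixel per batch image)
//   columns K = Y * X * C     (one column per filter tap per input channel)
// and element (g, row, column) gathers in_host(g, n, c, hi, wi) where
//   hi = ho * stride_h + y * dilation_h - left_pad_h   (and analogously for wi).
// Out-of-bound taps are skipped, so they keep whatever value out_host was initialized
// with (typically zero).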
} // namespace ck_tile
......@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
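// e.g. x_total = 100, num_splits = 3: x_per_split = ceil(100 / 3) = 34 and the raw split
// ranges become [0, 34), [34, 68), [68, 102); the min() against origin_end clips the last
// range back inside the valid tile range.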
......
......@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
......@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaBwdCommonDropoutKargs
struct FmhaBwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const float raw_scale)
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
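// Illustrative note (not part of the original source): drop_seed / drop_offset hold either
// an immediate value chosen on the host or a pointer to device memory that the kernel
// dereferences; is_drop_seed_offset_from_host selects which union member is valid.
// Callers pick the mode through the std::variant argument of MakeKargs, e.g.
//   std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
//       drop_seed_offset = std::make_pair(uint64_t{1234}, uint64_t{0});     // host values
//   // or: drop_seed_offset = std::make_pair(seed_dev_ptr, offset_dev_ptr); // device pointers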
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
......@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
......@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
......@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
......@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
......
......@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
......@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdCommonDropoutKargs
struct FmhaFwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
......@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
......@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
......@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t,
kargs.is_store_randval};
......
......@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
......@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
......@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
};
struct GroupModeKargs
......@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
......@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
seqlen_q,
hdim_v,
num_splits,
......@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o,
batch_stride_lse_acc};
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE)
{
......@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
......@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
......@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer
hdim_v,
num_splits,
......@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
......@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
......@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
......@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, we handle the batch stride by just offsetting the pointers
......@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
......@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q =
const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
......@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.max_seqlen_q,
kargs.seqlen_q,
smem_ptr);
}
else
......@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.max_seqlen_q,
kargs.seqlen_q,
smem_ptr);
}
}();
......
......@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead,
batch_size);
}
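// Illustrative example (hedged, assuming kM0 = 64 and kN1 = 64): batch_size = 2,
// nhead = 8, max_seqlen_q = 1000, hdim_v = 128 gives
//   dim3(ceil(1000 / 64) * ceil(128 / 64), 8, 2) = dim3(32, 8, 2).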
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
......
......@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
......@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
struct GroupModeKargs
......@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
......@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
......@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc,
batch_stride_o_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
......@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
......@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_o_acc = 0;
if constexpr(kIsGroupMode)
{
......@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
......@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.stride_o_acc;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
......@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.hdim_v, 1),
number<FmhaPipeline::kAlignmentO>{},
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<1>{});
return pad_tensor_view(
......
......@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
batch_size);
......
......@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
s_acc,
bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
......@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
......@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
......@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
......@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
......@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale
if constexpr(FmhaDropout::IsDropout)
......
......@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
......@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
......@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
......@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
......@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
......
......@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t max_seqlen_q,
index_t seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
......@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
......@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
......@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t max_seqlen_q,
index_t seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
......@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
max_seqlen_q,
seqlen_q,
smem_ptr);
}
};
......
......@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
......@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
......@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename Problem_>
struct ImageToColumn
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
using Problem = remove_cvref_t<Problem_>;
using InDataType = remove_cvref_t<typename Problem::InDataType>;
using OutDataType = remove_cvref_t<typename Problem::OutDataType>;
static constexpr index_t NDimSpatial = Problem::NDimSpatial;
static constexpr index_t AligmentIn = Problem::AligmentIn;
static constexpr index_t AligmentOut = Problem::AligmentOut;
static_assert(NDimSpatial == 2, "Not supported.");
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
struct Kargs
{
const void* p_in;
void* p_out;
const long_index_t G;
const long_index_t N;
const long_index_t C;
const array<long_index_t, NDimSpatial> input_spatial_lengths;
const array<long_index_t, NDimSpatial> filter_spatial_lengths;
const array<long_index_t, NDimSpatial> output_spatial_lengths;
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
const array<long_index_t, 3> gemm_g_m_k_strides;
const array<long_index_t, NDimSpatial> conv_filter_strides;
const array<long_index_t, NDimSpatial> conv_filter_dilations;
const array<long_index_t, NDimSpatial> input_left_pads;
const array<long_index_t, NDimSpatial> input_right_pads;
};
CK_TILE_HOST static constexpr Kargs
MakeKargs(const void* p_in,
void* p_out,
const long_index_t G,
const long_index_t N,
const long_index_t C,
const array<long_index_t, NDimSpatial> input_spatial_lengths,
const array<long_index_t, NDimSpatial> filter_spatial_lengths,
const array<long_index_t, NDimSpatial> output_spatial_lengths,
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
const array<long_index_t, 3> gemm_g_m_k_strides,
const array<long_index_t, NDimSpatial> conv_filter_strides,
const array<long_index_t, NDimSpatial> conv_filter_dilations,
const array<long_index_t, NDimSpatial> input_left_pads,
const array<long_index_t, NDimSpatial> input_right_pads)
{
return Kargs{p_in,
p_out,
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
{
return dim3(
integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(
kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
make_tuple(kargs.image_g_n_c_wis_strides[I1],
kargs.image_g_n_c_wis_strides[I3],
kargs.image_g_n_c_wis_strides[I4],
kargs.image_g_n_c_wis_strides[I2]),
number<AligmentIn>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(kargs.N),
make_pad_transform(kargs.input_spatial_lengths[I0],
kargs.input_left_pads[I0],
kargs.input_right_pads[I0]),
make_pad_transform(kargs.input_spatial_lengths[I1],
kargs.input_left_pads[I1],
kargs.input_right_pads[I1]),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(kargs.N),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]),
make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]),
make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(
kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
make_merge_transform(make_tuple(
kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
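// GEMM extents of the rearranged tensor: M = N * Ho * Wo (one row per output pixel per
// image) and K = Y * X * C (one column per filter tap per input channel).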
CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
kargs.output_spatial_lengths[I1]);
const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
kargs.filter_spatial_lengths[I1]);
return make_tuple(M, K);
}
CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
{
using P = typename Problem::BlockShape;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
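// Each workgroup copies one kMPerBlock x kKPerBlock tile of the im2col matrix:
// blockIdx.x/y select the (M, K) tile, blockIdx.z walks the G dimension, and only the K
// extent is padded up to the block tile (sequence<false, true>) before the load/store pair.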
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
{
const auto [M, K] = CalculateMKDims(kargs);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
const auto image_m_k = make_tensor_view<address_space_enum::global>(
static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(kargs.p_out) + out_offset,
make_tuple(M, K),
make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
number<AligmentOut>{},
I1);
const auto image_m_k_padded =
pad_tensor_view(image_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
const auto gemm_m_k_padded =
pad_tensor_view(gemm_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
constexpr auto dstr = MakeBlockTileDistribution();
const auto image_tile =
make_tile_window(image_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
auto gemm_tile = make_tile_window(gemm_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
// load from Global
const auto loaded_tile = load_tile(image_tile);
// save to Global
store_tile(gemm_tile, loaded_tile);
}
CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename InDataType_,
typename OutDataType_,
typename BlockShape_,
index_t NDimSpatial_,
index_t AligmentIn_,
index_t AligmentOut_>
struct BlockImageToColumnProblem
{
using InDataType = remove_cvref_t<InDataType_>;
using OutDataType = remove_cvref_t<OutDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr index_t AligmentIn = AligmentIn_;
static constexpr index_t AligmentOut = AligmentOut_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ThreadTile, // Sequence<...
typename WarpTile, // Sequence<...
typename BlockTile> // Sequence<...
struct TileImageToColumnShape
{
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
};
} // namespace ck_tile
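// A minimal instantiation sketch (hypothetical data types, alignments and tile sizes, not part
// of this commit). Assuming the three headers above are included, it wires the tile shape, the
// compile-time problem description and the kernel together, then derives the launch dimensions
// for a GEMM view with M = N * Ho * Wo and K = Y * X * C over G groups; the launch itself is
// omitted.
namespace example {
using Shape   = ck_tile::TileImageToColumnShape<ck_tile::sequence<4, 4>,     // per-thread tile
                                                ck_tile::sequence<16, 64>,   // per-warp tile
                                                ck_tile::sequence<32, 128>>; // per-block tile
using Problem = ck_tile::BlockImageToColumnProblem<ck_tile::fp16_t, // InDataType
                                                   ck_tile::fp16_t, // OutDataType
                                                   Shape,
                                                   2,  // NDimSpatial
                                                   8,  // AligmentIn
                                                   8>; // AligmentOut
using Kernel = ck_tile::ImageToColumn<Problem>;
// One workgroup per kMPerBlock x kKPerBlock output tile, one grid z-slice per group.
inline auto make_grid(ck_tile::index_t GemmM, ck_tile::index_t GemmK, ck_tile::index_t G)
{
    return Kernel::GridSize(GemmM, GemmK, G);
}
} // namespace example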
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
sequence<2>>{});
}
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
{
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
using Lengths = decltype(nDstrSpan.impl_);
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
ck_tile::index_t ret = 1;
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
if(n_per_block_tail_loop > 0)
{
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return ret;
return max_count;
}
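// Worked example for GetWelfordMaxCount (hypothetical tile sizes): with kNPerBlock = 128 and
// kNPerThread = 4 (so kNThreadPerBlock = 32), N = 200 gives every thread a base count of
// 4 * (200 / 128) = 4 plus a tail of 200 - 4 * 32 = 72 columns; the thread with
// thread_id_n = 17 owns tail columns 68..71 (delta = 0, count becomes 8), while
// thread_id_n = 18 owns columns 72..75 (delta = 4, count stays 4).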
template <typename DistributedTensor>
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
return out_dstr_tensor;
}
template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y,
MeanDataType* p_mean,
InvStdDataType* p_invStd,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
// TODO - Optimize tail loop to reduce move_tile_window()
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr auto NPerThread = GetNPerThread(xDstr);
ThreadWelford<ComputeDataType, XDataType> thread_welford{
type_convert<int>(NPerThread * N / kNPerBlock)};
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
{
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_mean, make_tuple(M), number<32>{});
auto mean_block_window =
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
}
if constexpr(kSaveInvStd)
{
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_invStd, make_tuple(M), number<32>{});
auto inv_std_block_window =
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
}
// TODO: Extract normalize pipeline
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
}
}
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
static_cast<const GammaDataType*>(kargs.p_gamma),
static_cast<const BetaDataType*>(kargs.p_beta),
static_cast<YDataType*>(kargs.p_y),
static_cast<MeanDataType*>(kargs.p_mean),
static_cast<InvStdDataType*>(kargs.p_invStd),
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.M,
kargs.N);
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto beta_n = [&]() {
const auto beta_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
beta_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window =
make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
if constexpr(kSaveMean)
{
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() {
if constexpr(kSaveInvStd)
{
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
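// When the whole row fits in a single block tile, the one-pass path reuses a single load of x
// for both the Welford statistics and the normalization; otherwise the two-pass path re-reads
// the row tile by tile.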
if(kargs.N <= kNPerBlock)
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
else
TwoPassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
}
};