Commit 63b152d6 authored by danyao12

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents ae2d7d2b 14c3cfb1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
{
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
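// Usage sketch (illustrative, not part of this diff): the new overload copies the
// first N elements of a runtime-sized vector into a fixed-size array. The caller
// must guarantee x.size() >= N, since the static_for above indexes x[0..N-1]
// unconditionally.
//   std::vector<int> lens{2, 3, 4, 5};
//   auto arr = to_array<int, 3>(lens); // arr holds {2, 3, 4}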
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{
...
...@@ -58,7 +58,7 @@ struct thread_buffer {
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
...
...@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp" #include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
......
...@@ -50,12 +50,22 @@ class ArgParser
}
return *this;
}
void print() const
{
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n"); printf("args:\n");
for(auto& key : keys) for(auto& key : keys)
{ {
auto value = input_map[key]; auto value = input_map.at(key);
std::vector<std::string> help_text_lines; std::vector<std::string> help_text_lines;
size_t pos = 0; size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
...@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl;
...@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end();
++help_next_line)
{
std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
}
}
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
namespace conv {
namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
} // namespace detail
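// Worked example (illustrative, not part of this diff): for NHWGC, whose
// physical dimension order is (N=0, H=1, W=2, G=3, C=4), the function above
// returns the new2old map {3, 0, 4, 1, 2}: logical dimension i of the
// GNCHW-ordered view reads from physical dimension new2old[i], i.e.
// G<-3, N<-0, C<-4, H<-1, W<-2.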
// make tensor descriptor for a packed input tensor, with dimensions ordered as GNCHW
// regardless of the physical layout
template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
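// Usage sketch (illustrative): for an NHWGC input this returns a descriptor
// whose lengths are ordered (G, N, C, Hi, Wi) while the strides still follow
// the physical NHWGC memory order, so host code can index logically:
//   auto in_desc = make_input_host_tensor_descriptor_g_n_c_wis_packed<
//       ck_tile::tensor_layout::convolution::NHWGC>(param);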
// make tensor descriptor for a packed weight tensor, with dimensions ordered as GKCYX
// regardless of the physical layout
template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make tensor descriptor for a packed output tensor, with dimensions ordered as GNKHW
// regardless of the physical layout
template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
} // namespace conv
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace ck_tile {
namespace conv {
struct ConvParam
{
ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count,
ck_tile::index_t n_batch,
ck_tile::index_t n_out_channels,
ck_tile::index_t n_in_channels,
const std::vector<ck_tile::index_t>& filters_len,
const std::vector<ck_tile::index_t>& input_len,
const std::vector<ck_tile::index_t>& strides,
const std::vector<ck_tile::index_t>& dilations,
const std::vector<ck_tile::index_t>& left_pads,
const std::vector<ck_tile::index_t>& right_pads)
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
G_(static_cast<ck_tile::long_index_t>(group_count)),
N_(static_cast<ck_tile::long_index_t>(n_batch)),
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
filter_spatial_lengths_(num_dim_spatial_),
input_spatial_lengths_(num_dim_spatial_),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(num_dim_spatial_),
conv_filter_dilations_(num_dim_spatial_),
input_left_pads_(num_dim_spatial_),
input_right_pads_(num_dim_spatial_)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
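// Worked example (illustrative): with X = 3, dilation = 2, stride = 2, Wi = 32
// and pads 1/1, the formulas above give XEff = (3 - 1) * 2 + 1 = 5 and
// Wo = (32 + 1 + 1 - 5) / 2 + 1 = 15.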
ConvParam(ck_tile::long_index_t n_dim,
ck_tile::long_index_t group_count,
ck_tile::long_index_t n_batch,
ck_tile::long_index_t n_out_channels,
ck_tile::long_index_t n_in_channels,
const std::vector<ck_tile::long_index_t>& filters_len,
const std::vector<ck_tile::long_index_t>& input_len,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads,
const std::vector<ck_tile::long_index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
K_(n_out_channels),
C_(n_in_channels),
filter_spatial_lengths_(filters_len),
input_spatial_lengths_(input_len),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(strides),
conv_filter_dilations_(dilations),
input_left_pads_(left_pads),
input_right_pads_(right_pads)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ck_tile::long_index_t num_dim_spatial_;
ck_tile::long_index_t G_;
ck_tile::long_index_t N_;
ck_tile::long_index_t K_;
ck_tile::long_index_t C_;
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
std::vector<ck_tile::long_index_t> conv_filter_strides_;
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
std::vector<ck_tile::long_index_t> input_left_pads_;
std::vector<ck_tile::long_index_t> input_right_pads_;
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
{
return output_spatial_lengths_;
}
std::size_t GetFlops() const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>());
}
template <typename InDataType>
std::size_t GetInputByte() const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename WeiDataType>
std::size_t GetWeightByte() const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename OutDataType>
std::size_t GetOutputByte() const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
GetOutputByte<OutDataType>();
}
};
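// Worked example (illustrative): a 2D problem with G=1, N=2, K=64, C=32,
// a 3x3 filter and 30x30 output gives
// GetFlops() = 2 * 1 * 2 * 64 * 32 * (30 * 30) * (3 * 3) = 66,355,200.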
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
std::string msg;
msg += "Following arguments (depending on number of spatial dims):\n"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
" G, N, K, C, \n"
" <filter spatial dimensions>, (ie Y, X for 2D)\n"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
" <strides>, (ie Sy, Sx for 2D)\n"
" <dilations>, (ie Dy, Dx for 2D)\n"
" <left padding>, (ie LeftPy, LeftPx for 2D)\n"
" <right padding>, (ie RightPy, RightPx for 2D)\n";
return msg;
}
CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stol(argv[arg_idx++]);
}
return ck_tile::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
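// Invocation sketch (illustrative): the arguments are consumed in the order
// documented by get_conv_param_parser_helper_msg(), e.g. for a 2D problem:
//   2  1 128 256 192  3 3  71 71  2 2  1 1  1 1  1 1
// i.e. num_dim_spatial, then G N K C, filter Y X, input Hi Wi, strides,
// dilations, left pads, right pads, with arg_idx pointing at the G value when
// calling parse_conv_param(2, arg_idx, argv).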
} // namespace conv
} // namespace ck_tile
...@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
{
os << "dim " << desc.get_num_of_dimension() << ", ";
os << "lengths {";
LogRange(os, desc.get_lengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.get_strides(), ", ");
os << "}";
return os;
}
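// Example output sketch (illustrative): a packed 2x4 row-major descriptor
// prints as
//   dim 2, lengths {2, 4}, strides {4, 1}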
private:
std::vector<std::size_t> mLens;
...
...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
}
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}
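// Usage sketch (illustrative, assuming the template parameter order shown in
// the hunk above and pre-filled tensors):
//   HostTensor<float> a({16, 32}), b({32, 8}), c({16, 8});
//   reference_gemm<float, float, float, float,
//                  tensor_layout::gemm::RowMajor,   // A stored M x K
//                  tensor_layout::gemm::RowMajor,   // B stored K x N
//                  tensor_layout::gemm::RowMajor>(a, b, c);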
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
}
}
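// Worked example (illustrative): for a column-major A with strideA = M (the
// leading dimension), element (row, k) lives at k * strideA + row, which is
// the second branch of a_index above; a row-major A uses row * strideA + k.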
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...@@ -9,53 +9,125 @@
namespace ck_tile {
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
HostTensor<OutDataType>& out_host,
const ck_tile::conv::ConvParam& conv_params)
{
const long_index_t G = in_host.get_lengths()[0];
const long_index_t N = in_host.get_lengths()[1];
const long_index_t C = in_host.get_lengths()[2];
if constexpr(NDimSpatial == 1)
{
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
{
InDataType v_in = in_host(g, n, c, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
{
InDataType v_in = in_host(g, n, c, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t Do = conv_params.output_spatial_lengths_[0];
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
hi >= 0 &&
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
{
InDataType v_in = in_host(g, n, c, di, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
} }
} }
} // namespace ck_tile
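// Usage sketch (illustrative): lower a 2D input described by
// make_input_host_tensor_descriptor_g_n_c_wis_packed (lengths G, N, C, Hi, Wi)
// into the implicit-GEMM input matrix of shape (G, N*Ho*Wo, Y*X*C). Note that
// out_host should be zero-initialized, since padded positions are skipped
// rather than written.
//   HostTensor<float> in(in_desc), gemm_in({G, N * Ho * Wo, Y * X * C});
//   ck_tile::reference_im2col<float, float, 2>(in, gemm_in, conv_params);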
...@@ -3,5 +3,6 @@
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define CK_TILE_MAX_RANK 5
namespace ck_tile {
// This epilogue stores a matrix from shared memory to global memory with a
// different (permuted) layout.
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool kTilePermute_,
index_t kRank_,
index_t kPerm0,
index_t kPerm1,
index_t TileSize0,
index_t TileSize1,
index_t kPerm2 = 0,
index_t kPerm3 = 0,
index_t kPerm4 = 0,
index_t TileSize2 = 0,
index_t TileSize3 = 0,
index_t TileSize4 = 0>
struct CShuffleEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTilePermute = kTilePermute_;
static constexpr index_t kRank = kRank_;
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
};
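// Instantiation sketch (illustrative; the data types are placeholders): a
// rank-2 problem that swaps the two tile axes of a 32x64 tiling:
//   using Problem = CShuffleEpilogueProblem<float, ck_tile::half_t,
//                                           /*kPadM=*/false, /*kPadN=*/false,
//                                           /*kTilePermute=*/true, /*kRank=*/2,
//                                           /*kPerm0=*/1, /*kPerm1=*/0,
//                                           /*TileSize0=*/32, /*TileSize1=*/64>;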
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
const index_t* kPerm = Problem::kPerm;
static constexpr bool kTilePermute = Problem::kTilePermute;
static constexpr index_t kRank = Problem::kRank;
const index_t* tile_sizes = Problem::tile_sizes;
// No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{
using DataType = typename OAccTile::DataType;
// Get thread buffer
auto& thread_buf = o_acc_tile.get_thread_buffer();
// Create a temporary buffer to hold the permuted data
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
// Get the lengths of each dimension
auto thread_tensor_lengths = o_acc_tile.get_lengths();
// Total number of elements
index_t total_elements = OAccTile::kThreadElementSpaceSize;
// Iterate over all elements
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
{
// Convert linear index to multi-dimensional indices
array<index_t, kRank> indices;
index_t remaining = linear_idx;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
remaining /= thread_tensor_lengths.get(number<rev_i>{});
});
// Apply the permutation
array<index_t, kRank> permuted_indices;
static_for<0, kRank, 1>{}(
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
// Compute offsets
index_t dst_offset = 0;
index_t stride = 1;
static_for<0, kRank, 1>{}([&](auto i) {
constexpr auto rev_i = kRank - 1 - i;
dst_offset += permuted_indices[rev_i] * stride;
stride *= thread_tensor_lengths.get(number<rev_i>{});
});
// Move the data
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
}
// Copy the permuted data back to the original thread buffer
for(index_t i = 0; i < total_elements; ++i)
{
thread_buf.set_as(i, permuted_thread_buf.get(i));
}
}
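// Worked example (illustrative): with kRank = 2, thread-tensor lengths {8, 8}
// and kPerm = {1, 0}, linear index 13 decomposes to indices {1, 5}; applying
// the permutation gives {5, 1}, which re-linearizes over the same lengths to
// dst_offset = 5 * 8 + 1 = 41. The re-linearization uses the original lengths,
// so the permuted dimensions are expected to have matching extents.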
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
for(index_t i = 0; i < kRank; ++i)
{
permuted_tile_coords[i] = tile_coords[kPerm[i]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
for(index_t i = 0; i < kRank; ++i)
{
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename ODramWindowTmp::BottomTensorIndex step = {};
for(index_t i = 0; i < kRank; ++i)
{
step[i] = permuted_window_origin[i] - current_window_origin[i];
}
// Move the window
move_tile_window(o_dram_window_tmp, step);
// Permute the data within the tile if necessary
if constexpr(kTilePermute)
{
permute_tile_data(o_acc_tile);
}
// Store the tile data to the permuted location
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
} // namespace ck_tile
...@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
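// Worked example (illustrative): x_total = 10, num_splits = 4. Previously
// x_per_split = 10 / 4 = 2 gave splits [0,2) [2,4) [4,6) and a tail [6,10)
// twice as large as the rest; with integer_divide_ceil, x_per_split = 3 gives
// [0,3) [3,6) [6,9) [9,12), each clipped against [origin_start, origin_end),
// so the tail split shrinks instead of growing.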
...
...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaBwdDropoutSeedOffset
{
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
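// Sketch (illustrative): the union stores either an immediate value or a device
// pointer, and is_drop_seed_offset_from_host records which member is active,
// since a bare union carries no discriminator of its own:
//   FmhaBwdDropoutSeedOffset s;
//   s.drop_seed.val = 1234; // host-side immediate
//   s.is_drop_seed_offset_from_host = true;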
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
...@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
...@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
...@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
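// Caller-side sketch (illustrative): either variant alternative selects the
// matching init_dropout overload above; seed_dev/offset_dev are hypothetical
// device pointers.
//   auto host_args = std::make_pair(uint64_t{1234}, uint64_t{0});          // index() == 0
//   std::pair<const void*, const void*> dev_args{seed_dev, offset_dev};    // index() == 1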
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
...@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
...@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
...@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
...
...@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdDropoutSeedOffset
{
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
...@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
...@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
...@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
...@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t,
kargs.is_store_randval};
...
...@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
...@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
...@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
};
struct GroupModeKargs
...@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
...@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
seqlen_q,
hdim_v,
num_splits,
...@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE)
{
...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits, ck_tile::index_t num_splits,
...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr, o_acc_ptr,
o_ptr, o_ptr,
batch, batch,
max_seqlen_q,
-1, // seqlen_q placeholder; real value computed per batch from seqstart_q_ptr -1, // seqlen_q placeholder; real value computed per batch from seqstart_q_ptr
hdim_v, hdim_v,
num_splits, num_splits,
...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0; long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0; long_index_t batch_offset_o = 0;
...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start; batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = query_start; batch_offset_lse = query_start;
} }
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
} }
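// [editor note] Sketch of the resulting per-batch offset math (illustrative,
// mirrors the code above):
//   group mode: offset = seqstart_q_ptr[i_batch] * row stride
//               e.g. batch_offset_o_acc = query_start * kargs.row_stride_o_acc
//   batch mode: offset = i_batch * batch stride
//               e.g. batch_offset_o_acc = i_batch * kargs.batch_stride_o_acc
// In both paths the multiply is done in long_index_t to avoid 32-bit overflow
// on large tensors.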
// for simplicity, batch strides are applied by just advancing the base pointers // for simplicity, batch strides are applied by just advancing the base pointers
...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() { auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr, o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q = const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v = const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view( return transform_tensor_view(
o_acc_dram_view, o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)), make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
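// [editor note] The merge transform flattens (num_splits, padded_seqlen_q)
// into a single row axis, so one 2-D window can walk every split of a row
// range. With hypothetical sizes num_splits = 4 and padded_seqlen_q = 128,
// the view has shape (512, padded_hdim_v) and row r maps back to
// (split = r / 128, row_in_split = r % 128).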
...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
else else
...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window, o_acc_dram_window,
lse_dram_window, lse_dram_window,
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
}(); }();
......
...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner ...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_; static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1), ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead_, nhead,
batch_size_); batch_size);
} }
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{ {
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x; const index_t i_block = blockIdx.x;
......
...@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel ...@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc; ck_tile::index_t split_stride_o_acc;
}; };
...@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel ...@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel ...@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_k_ptr; const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr; const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v; // only used for paged-kvcache
}; };
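// [editor note] Presumably batch_stride_k/batch_stride_v survive in group
// mode only because a paged KV cache addresses K/V through a block table
// rather than through seqstart_k offsets, so an explicit stride is still
// needed there; the non-paged group-mode path computes its K/V offsets from
// seqstart_k_ptr instead.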
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
...@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel ...@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
...@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel ...@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(seqlen_k_ptr), reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v}; batch_stride_v,
batch_stride_lse_acc,
batch_stride_o_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel ...@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_k, // only used for paged-kvcache
ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
...@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel ...@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
...@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel ...@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel ...@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc = long_index_t batch_offset_o_acc = 0;
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
...@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel ...@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
...@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel ...@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.stride_o_acc;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
...@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel ...@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel ...@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr, o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.hdim_v, 1), make_tuple(kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentO>{}, number<1>{},
number<1>{}); number<1>{});
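// [editor note] The previous view hard-coded a packed row stride
// (kargs.hdim_v) and advertised kAlignmentO vector alignment. With the row
// stride now taken from the runtime stride_o_acc, rows are not guaranteed to
// be contiguous, so the guaranteed alignment is conservatively dropped to 1.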
return pad_tensor_view( return pad_tensor_view(
......
...@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner ...@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1), ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits, nhead * num_splits,
batch_size); batch_size);
......
...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window = auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window = auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}), make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(), k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>( auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
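// [editor note] The LDS write window now covers the full kQKHeaddim while the
// read window still consumes kK0 columns per pass, using the register-block
// descriptor directly instead of a separate "slice" descriptor. Presumably K
// is staged through LDS once and then sliced from registers per GEMM
// iteration; this is what lets the kQKHeaddim == kK0 restriction below relax
// to >=.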
...@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>()); v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window = auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0}); make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window = auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}), make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(), v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>()); Policy::template MakeVRegBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------ //------------------------------------------------------------------
// KT, Reg ->LDS ->Reg // KT, Reg ->LDS ->Reg
...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window( auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>( auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds(); block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window); auto v_reg_tensor = load_tile(v_lds_read_window);
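// [editor note] v_reg_tensor is now declared at its single assignment, so
// load_tile deduces the distributed-tensor type; the up-front
// make_static_distributed_tensor<VDataType>(MakeVRegBlockDescriptor...) is no
// longer needed.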
block_sync_lds(); block_sync_lds();
//---------------------------- Loop Load in ----------------------------// //---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS // Q: HBM ->Reg ->LDS
...@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>()); q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window = auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window = auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(), make_tile_window(q_lds_window.get_bottom_tensor_view(),
...@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window( auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>( auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
...@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>()); do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window = auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window = auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(), make_tile_window(do_lds_window.get_bottom_tensor_view(),
...@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window( auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>( auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
...@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t i_total_loops = 0; index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start; index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal kK0"); static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be greater than or equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal kK1"); static_assert(kM0 == kK1, "kM0 should equal kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal kK2"); static_assert(kVHeaddim >= kK2, "kVHeaddim should be greater than or equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal kK3"); static_assert(kM0 == kK3, "kM0 should equal kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
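// [editor note] With the relaxed asserts the head dimension may span several
// K slices. Illustrative count (hypothetical values): kQKHeaddim = 128 with
// kK0 = 32 gives 128 / 32 = 4 slices for the Q@K^T GEMM, and kVHeaddim / kK2
// slices for the OGrad@V GEMM; k4_loops = kN0 / kK4 is unchanged.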
......
...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window = auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window = auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}), make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(), k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>( auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
...@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>()); v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window = auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0}); make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window = auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}), make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(), v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>()); Policy::template MakeVRegBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------ //------------------------------------------------------------------
// KT, Reg ->LDS ->Reg // KT, Reg ->LDS ->Reg
...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window( auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>( auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds(); block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window); auto v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------// //---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS // Q: HBM ->Reg ->LDS
auto q_dram_window = auto q_dram_window =
...@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>()); q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window = auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window = auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(), make_tile_window(q_lds_window.get_bottom_tensor_view(),
...@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window( auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>( auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
...@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>()); do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window = auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window = auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(), make_tile_window(do_lds_window.get_bottom_tensor_view(),
...@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window( auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>( auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
...@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t i_total_loops = 0; index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start; index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal kK0"); static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be greater than or equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal kK1"); static_assert(kM0 == kK1, "kM0 should equal kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal kK2"); static_assert(kVHeaddim >= kK2, "kVHeaddim should be greater than or equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal kK3"); static_assert(kM0 == kK3, "kM0 should equal kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
...@@ -827,6 +824,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -827,6 +824,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}, },
s_acc, s_acc,
bias_s_tile); bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
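// [editor note] __builtin_amdgcn_sched_barrier(0) is the LLVM AMDGPU
// scheduling intrinsic; a mask of 0 means no instruction may be moved across
// this point by the compiler's scheduler. One barrier after each stage pins
// the HotLoopScheduler's hand-tuned MFMA/memory interleaving in place.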
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
...@@ -918,6 +916,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -918,6 +916,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>(); HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{}; auto dp_acc = SPGradBlockTileType{};
...@@ -927,6 +926,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -927,6 +926,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>(); HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{}; auto ds = SPGradBlockTileType{};
...@@ -965,6 +965,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -965,6 +965,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile); shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile); store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
...@@ -984,6 +985,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -984,6 +985,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4}); move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>(); HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4 // STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{}; auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); clear_tile(dq_acc);
...@@ -1005,6 +1007,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -1005,6 +1007,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}); });
HotLoopScheduler::template GemmStagedScheduler<4>(); HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale // Results Scale
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
......