Unverified Commit 2cab8d39 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

CK Tile FA Training kernels (#1286)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 76827d82
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
class philox
{
public:
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
: seed(reinterpret_cast<const uint2&>(seed_))
{
ull2* tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset_;
}
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
{
uint4 counter_ = counter;
ull2* tmp = reinterpret_cast<ull2*>(&counter_);
tmp->y = subsequence;
uint2 key_ = seed;
// 7-round philox
#pragma unroll
for(int i = 0; i < 6; i++)
{
counter_ = philox_single_round(counter_, key_);
key_.x += kPhilox10A;
key_.y += kPhilox10B;
}
uint4 output = philox_single_round(counter_, key_);
return output;
}
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
const unsigned long long subsequence) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp_ph.x;
out_tmp[1] = tmp_ph.y;
out_tmp[2] = tmp_ph.z;
out_tmp[3] = tmp_ph.w;
}
private:
struct ull2
{
uint64_t x;
uint64_t y;
};
uint4 counter;
const uint2 seed;
CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
{
uint2* res;
unsigned long long tmp;
tmp = static_cast<unsigned long long>(a) * b;
res = reinterpret_cast<uint2*>(&tmp);
return *res;
}
CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
{
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
} // namespace ck_tile
......@@ -11,6 +11,7 @@
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -156,7 +156,7 @@ struct HostTensorDescriptor
}
const std::vector<std::size_t>& get_lengths() const { return mLens; }
const std::vector<std::size_t>& GetStrides() const { return mStrides; }
const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
......@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
{
new_lengths[i] = a.get_lengths()[new2old[i]];
new_strides[i] = a.GetStrides()[new2old[i]];
new_strides[i] = a.get_strides()[new2old[i]];
}
return HostTensorDescriptor(new_lengths, new_strides);
......@@ -327,7 +327,7 @@ struct HostTensor
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
decltype(auto) get_strides() const { return mDesc.get_strides(); }
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
......@@ -481,6 +481,34 @@ struct HostTensor
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
HostTensor<T> transpose(std::vector<size_t> axes = {}) const
{
if(axes.empty())
{
axes.resize(this->get_num_of_dimension());
std::iota(axes.rbegin(), axes.rend(), 0);
}
if(axes.size() != mDesc.get_num_of_dimension())
{
throw std::runtime_error(
"HostTensor::transpose(): size of axes must match tensor dimension");
}
std::vector<size_t> tlengths, tstrides;
for(const auto& axis : axes)
{
tlengths.push_back(get_lengths()[axis]);
tstrides.push_back(get_strides()[axis]);
}
HostTensor<T> ret(*this);
ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
return ret;
}
HostTensor<T> transpose(std::vector<size_t> axes = {})
{
return const_cast<HostTensor<T> const*>(this)->transpose(axes);
}
typename Data::iterator begin() { return mData.begin(); }
typename Data::iterator end() { return mData.end(); }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename DataType, typename RandValOutputDataType>
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
const HostTensor<RandValOutputDataType>& randval_b_m_n,
const uint8_t& p_undrop_in_uint8_t,
const float scale)
{
const int N = in_out_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
? ck_tile::type_convert<DataType>(tmp)
: DataType(0);
}
};
make_ParallelTensorFunctor(
f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile
......@@ -4,10 +4,24 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
struct BlockDropout
{
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_,
bool is_store_randval_)
: ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
is_store_randval(is_store_randval_)
{
}
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
number<kN1>{},
number<1>{});
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
randval_lds_block_desc_0,
ck_tile::make_tuple(
make_pass_through_transform(number<kMPerStep>{}),
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
return randval_lds_block_desc;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd.
constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>)
{
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
}
else
{
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
}
}();
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
randval_block_inner_part_dstr_encoding);
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
typename WG::CWarpDstrEncoding{});
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx] * rp_undrop
: PComputeDataType(0);
});
});
// save to Global
if(is_store_randval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
});
if(is_store_randval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
});
if(is_store_randval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
}
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval.kThreadElementSpaceSize == 16);
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start = (start_m0_idx / WG::kM) + i_m0;
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t[i_random_idx++];
constexpr auto p_idx0 =
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx]
: -p_compute[p_idx];
});
});
// save to Global
if(is_store_randval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0});
}
});
if(is_store_randval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
}
});
if(is_store_randval)
{
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
}
}
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -141,6 +141,36 @@ struct GenericAttentionMask
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
......@@ -160,14 +190,14 @@ struct GenericAttentionMask
}
else
{
return i_x >= x_end;
return i_x >= x_end || i_y >= y_total;
}
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
......@@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
......@@ -283,13 +343,13 @@ struct SimplifiedGenericAttentionMask
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end;
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
......@@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
}
} // namespace ck_tile
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,11 +9,11 @@
#include <string>
#include <type_traits>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q]
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace ck_tile {
......@@ -32,6 +32,8 @@ struct FmhaFwdKernel
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using RandValOutputDataType =
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
......@@ -45,6 +47,7 @@ struct FmhaFwdKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -84,7 +87,7 @@ struct FmhaFwdKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
......@@ -111,6 +114,7 @@ struct FmhaFwdKernel
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
......@@ -163,11 +167,35 @@ struct FmhaFwdKernel
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs
struct FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_lse = 0;
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
}
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaFwdBatchModeKargs
......@@ -178,8 +206,9 @@ struct FmhaFwdKernel
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
......@@ -196,7 +225,8 @@ struct FmhaFwdKernel
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
......@@ -211,12 +241,14 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
......@@ -225,22 +257,28 @@ struct FmhaFwdKernel
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -250,6 +288,7 @@ struct FmhaFwdKernel
seqlen_k,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
......@@ -268,6 +307,7 @@ struct FmhaFwdKernel
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v,
......@@ -302,6 +342,15 @@ struct FmhaFwdKernel
kargs.scale_p = scale_p;
kargs.scale_o = scale_o;
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.batch_stride_randval = batch_stride_randval;
kargs.is_store_randval = s_randval;
}
return kargs;
}
......@@ -312,6 +361,7 @@ struct FmhaFwdKernel
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* rand_val_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
......@@ -319,6 +369,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
......@@ -327,16 +378,22 @@ struct FmhaFwdKernel
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type)
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
......@@ -346,6 +403,7 @@ struct FmhaFwdKernel
-1, //
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
......@@ -364,6 +422,7 @@ struct FmhaFwdKernel
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
......@@ -389,12 +448,21 @@ struct FmhaFwdKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.scale_p = scale_p;
kargs.scale_o = scale_o;
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.is_store_randval = s_randval;
}
return kargs;
}
......@@ -426,12 +494,13 @@ struct FmhaFwdKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kIsGroupMode)
{
......@@ -455,7 +524,11 @@ struct FmhaFwdKernel
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
}
batch_offset_o = query_start * kargs.stride_o;
......@@ -493,6 +566,11 @@ struct FmhaFwdKernel
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kHasDropout)
{
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
......@@ -666,6 +744,62 @@ struct FmhaFwdKernel
}
}();
// dropout
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
}
BlockDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t,
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout)
{
RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
batch_offset_randval;
const auto randval_dram = [&]() {
const auto randval_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
rand_val_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_randval, 1),
number<1>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,
randval_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}();
return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(randval_dram_window_lengths);
}
}();
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
......@@ -723,6 +857,7 @@ struct FmhaFwdKernel
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
......@@ -731,7 +866,8 @@ struct FmhaFwdKernel
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
smem_ptr,
dropout);
}
else
{
......@@ -739,11 +875,13 @@ struct FmhaFwdKernel
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
smem_ptr,
dropout);
}
}();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy>
struct BlockFmhaBwdOGradDotO
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kVHeaddim = Problem::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
template <typename ODramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename DDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
DDramBlockWindowTmp& d_dram_block_window_tmp,
float p_undrop) const
{
static_assert(
std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize ==
OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
auto o_dram_window =
make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(),
o_dram_block_window_tmp.get_window_lengths(),
o_dram_block_window_tmp.get_window_origin(),
Policy::template MakePreODramTileDistribution<Problem>());
auto o = load_tile(o_dram_window);
auto do_dram_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
do_dram_block_window_tmp.get_window_origin(),
Policy::template MakePreOGradDramTileDistribution<Problem>());
auto do_ = load_tile(do_dram_window);
// declare d
constexpr auto d_dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));
auto d = make_static_distributed_tensor<DDataType>(d_dstr);
clear_tile(d); // Initialize D
constexpr auto o_spans = decltype(o)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
d(i_idx) +=
(type_convert<DDataType>(o[i_j_idx]) * type_convert<DDataType>(do_[i_j_idx]));
});
});
tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);
store_tile(d_dram_block_window_tmp, d);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ false,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ false,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k located in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, q & k & do located in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment