Unverified commit d71189ff authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents f84e2020 73b67f29
@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
                 return false;
             }
         }
-        if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+        if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
         {
             return false;
         }
...
@@ -12,7 +12,7 @@ namespace device {
 // 1d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NWGK_GKXC_NWGC()
+constexpr bool is_NWGC_GKXC_NWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
 }
 
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNWK_GKXC_GNWC()
+constexpr bool is_GNWC_GKXC_GNWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
 }
 
 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NHWGK_GKYXC_NHWGC()
+constexpr bool is_NHWGC_GKYXC_NHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
 }
 
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNHWK_GKYXC_GNHWC()
+constexpr bool is_GNHWC_GKYXC_GNHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
 }
 
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCHW_GKYXC_NGKHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
+}
+
 // 3d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
+constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
 }
 
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
+constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
 }
 
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
+constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
+}
+
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
 {
-    return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
+    return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
 }
 
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
+constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
 {
-    return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
+    return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }
 
 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
...
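Editor's note: the renames above make the function names match the layouts the bodies actually check (in/out letters were swapped in the old names). Traits like these are typically consulted from a device operator's argument check, as in the Large_Tensor change at the top. A minimal sketch of that pattern; the wrapper name here is illustrative, not from this commit:

// Sketch: compile-time layout gate built on the renamed traits.
template <typename ALayout, typename BLayout, typename ELayout>
constexpr bool is_supported_grouped_conv_layout()
{
    return is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
           is_GNSpatialC_GKSpatial_GNSpatialK<ALayout, BLayout, ELayout>();
}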
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
     static constexpr const char* name = "NDHWGC";
 };
 
+// input tensor
+// packed NGCW/NGCHW/NGCDHW
+struct NGCW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCW";
+};
+
+struct NGCHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCHW";
+};
+
+struct NGCDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCDHW";
+};
+
 // input tensor
 // strided layout
 struct G_NW_C : public BaseTensorLayout
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
     static constexpr const char* name = "NDHWGK";
 };
 
+struct NGKW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKW";
+};
+
+struct NGKHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKHW";
+};
+
+struct NGKDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKDHW";
+};
+
 // output tensor
 // strided layout
 struct G_NW_K : public BaseTensorLayout
...
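For readers new to these tags: each letter is one dimension in memory order (N = batch, G = group, C = input channels, K = output channels, D/H/W = spatial), so NGCHW places W innermost. A sketch of the packed-stride arithmetic that order implies (illustrative, not code from this commit):

// Packed NGCHW strides for shape {N, G, C, H, W}: the rightmost letter moves fastest.
// stride_W = 1
// stride_H = W
// stride_C = H * W
// stride_G = C * H * W
// stride_N = G * C * H * W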
@@ -41,6 +41,55 @@ __global__ void
         elementwise_op);
 }
 
+template <typename GridwiseElementwiseFunctor,
+          typename InAGridDescTuple,
+          typename InBGridDescTuple,
+          typename OutAGridDescTuple,
+          typename OutBGridDescTuple,
+          typename InDataTypePointerTuple,
+          typename OutDataTypePointerTuple,
+          typename Block2TileMapA,
+          typename Block2TileMapB,
+          typename ElementwiseOperation>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
+                                const InBGridDescTuple in_grid_desc_tuple_b,
+                                const OutAGridDescTuple out_grid_desc_tuple_a,
+                                const OutBGridDescTuple out_grid_desc_tuple_b,
+                                const InDataTypePointerTuple p_in_global_tuple_a,
+                                const InDataTypePointerTuple p_in_global_tuple_b,
+                                const OutDataTypePointerTuple p_out_global_tuple_a,
+                                const OutDataTypePointerTuple p_out_global_tuple_b,
+                                const Block2TileMapA block_2_tile_map_a,
+                                const Block2TileMapB block_2_tile_map_b,
+                                const ElementwiseOperation elementwise_op,
+                                const index_t a_grid_size)
+{
+    // blocks [0, a_grid_size) serve problem A; remaining blocks serve problem B
+    // with a rebased block id
+    if(get_block_1d_id() < a_grid_size)
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
+                                        out_grid_desc_tuple_a,
+                                        p_in_global_tuple_a,
+                                        p_out_global_tuple_a,
+                                        block_2_tile_map_a,
+                                        elementwise_op,
+                                        get_block_1d_id());
+    }
+    else
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
+                                        out_grid_desc_tuple_b,
+                                        p_in_global_tuple_b,
+                                        p_out_global_tuple_b,
+                                        block_2_tile_map_b,
+                                        elementwise_op,
+                                        get_block_1d_id() - a_grid_size);
+    }
+}
+
 template <typename GridwiseElementwiseFunctor,
           typename InGridDescTuple,
           typename OutGridDescTuple,
@@ -133,7 +182,8 @@ struct GridwiseElementwise
                                const InDataTypePointerTuple& p_in_global_tuple,
                                const OutDataTypePointerTuple& p_out_global_tuple,
                                const Block2TileMap& block_2_tile_map,
-                               const ElementwiseOperation& elementwise_op)
+                               const ElementwiseOperation& elementwise_op,
+                               const index_t block_id = get_block_1d_id())
    {
        constexpr auto src_datas = generate_tuple(
@@ -169,7 +219,7 @@ struct GridwiseElementwise
            Number<NumOutput>{});
 
        const auto block_work_idx =
-           block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+           block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
 
        const index_t m0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
...
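Editor's note: the original diff declared the first kernel parameter as InBGridDescTuple; it is written as InAGridDescTuple above, which is clearly what the template parameter list intends. kernel_elementwise_dual packs two independent elementwise problems into one launch. A hedged host-side sketch of the launch arithmetic, assuming the block-to-tile maps expose the usual CalculateGridSize convention:

// Sketch: one fused launch instead of two back-to-back kernels.
const index_t a_grid_size = block_2_tile_map_a.CalculateGridSize(out_grid_desc_a);
const index_t b_grid_size = block_2_tile_map_b.CalculateGridSize(out_grid_desc_b);
// grid.x = a_grid_size + b_grid_size; the kernel splits on a_grid_size:
// kernel_elementwise_dual<<<a_grid_size + b_grid_size, BlockSize, 0, stream>>>(
//     ..., a_grid_size);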
@@ -46,6 +46,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
+#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
 
 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
@@ -156,6 +157,14 @@
 #endif
 #endif
 
+#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
+#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif
...
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
     standard = 0, // rtn
     truncate_with_nan,
     truncate,
+    standard_asm,
 };
 
 template <bf16_rounding_mode rounding =
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
     return uint16_t(u.int32 >> 16);
 }
 
+CK_TILE_HOST
+constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
+
+CK_TILE_DEVICE
+uint16_t float_to_bf16_rtn_asm(float f)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {f};
+
+    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
+    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
+
+    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
+    uint32x2_t check_nan;
+    uint32_t tmp;
+
+    asm volatile("\n \
+        v_cmp_u_f32 %0, %2, %2 \n \
+        v_bfe_u32 %1, %2, 16, 1 \n \
+        v_add3_u32 %1, %2, %1, %3 \n \
+        v_cndmask_b32 %2, %1, %4, %0 \n \
+        v_lshrrev_b32 %2, 16, %2 \n \
+        "
+                 : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
+                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
+
+    return uint16_t(u.int32);
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE
 constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
 {
     if constexpr(rounding == bf16_rounding_mode::standard)
         return float_to_bf16_rtn_raw(f);
+    else if constexpr(rounding == bf16_rounding_mode::standard_asm)
+        return float_to_bf16_rtn_asm(f);
     else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
         return float_to_bf16_truc_nan_raw(f);
     else
...
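Both the C++ and asm paths implement round-to-nearest-even to bf16: add 0x7fff plus the mantissa bit that will become the bf16 LSB, truncate the low 16 bits, and force NaN inputs to the quiet pattern 0x7fff. An equivalent plain-C++ sketch of what the asm computes (illustrative, not part of the commit):

#include <cstdint>
#include <cstring>

inline uint16_t float_to_bf16_rne_sketch(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    if((u & 0x7fffffffu) > 0x7f800000u)  // NaN (unordered with itself: v_cmp_u_f32)
        return uint16_t(0x7fff);         // quiet-NaN bf16 (FP32_NAN >> 16)
    const uint32_t lsb = (u >> 16) & 1u; // v_bfe_u32: future bf16 LSB
    u += 0x7fffu + lsb;                  // v_add3_u32: round-to-nearest-even bias
    return uint16_t(u >> 16);            // v_lshrrev_b32: truncate mantissa
}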
@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
 CK_TILE_HOST
 float log(float x) { return std::logf(x); };
 
-CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
 {
-    // TODO: this is hacky, we use u16
     return __builtin_amdgcn_sad_u16(x, y, acc);
 }
 
-CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
+{
+    /// TODO: replace inline asm when intrinsic is available
+    uint32_t res;
+    asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
+    return res;
+}
+
+CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
 {
     return (x > y ? (x - y) : (y - x)) + acc;
 }
...
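Both flavors compute the sum of absolute differences, |x - y| + acc, symmetric in x and y. A quick host-side sanity check (requires <cassert>):

// sad_u32(7, 3, 10) == |7 - 3| + 10 == 14
assert(sad_u32(7u, 3u, 10u) == 14u);
assert(sad_u32(3u, 7u, 10u) == 14u);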
@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
 
+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move thread's window adaptor coordinate and bottom tensor coordinate
     // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
     CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
                     bottom_tensor_thread_coord,
                     bool_constant<oob_conditional_check>{},
                     pre_nop_);
-#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
+    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
                 asm volatile(
                     ""); // this is starting from rocm-6.2, but same symptom, reuse this flag
 #endif
@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
 
+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+    }
+
+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move window-origin
     CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
         tensor_view, window_lengths, origin};
 }
 
+// duplicate tile window and replace its origin
+template <typename TensorView, typename WindowLengths>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin)
+{
+    return tile_window_with_static_lengths<TensorView, WindowLengths>{
+        tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            origin,
+                            tile_distribution);
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            tile_window.get_window_origin(),
+                            tile_distribution);
+}
+
 template <typename TensorView_, typename WindowLengths_>
 CK_TILE_DEVICE void move_tile_window(
     tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
...
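The new overloads re-derive a window from an existing one, swapping the origin and/or attaching a distribution while reusing the bottom tensor view and lengths; the PageBlockNavigator below relies on these together with set_bottom_tensor_view_data_ptr. An illustrative use (the view, lengths, and distribution names are placeholders):

// Sketch: rebase an existing static-lengths window without re-stating view/lengths.
auto win      = make_tile_window(tensor_view, window_lengths, {0, 0});
auto shifted  = make_tile_window(win, make_multi_index(0, 64)); // same view, new origin
auto loadable = make_tile_window(win, tile_distribution);       // same origin, + distribution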
@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
 template <typename T>
 using remove_pointer_t = typename std::remove_pointer<T>::type;
 
+template <typename From, typename To>
+struct copy_const
+{
+    static_assert(!std::is_const_v<From>);
+    using type = To;
+};
+
+template <typename From, typename To>
+struct copy_const<const From, To>
+{
+    using type = std::add_const_t<typename copy_const<From, To>::type>;
+};
+
+template <typename From, typename To>
+using copy_const_t = typename copy_const<From, To>::type;
+
 namespace detail {
 template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
 struct detector
...
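copy_const propagates const-ness from one type onto another; PageBlockNavigator below uses copy_const_t<DataType, void>* so a const element type yields a pointer-to-const pool. Quick compile-time checks (requires <type_traits>):

static_assert(std::is_same_v<copy_const_t<float, void>, void>);
static_assert(std::is_same_v<copy_const_t<const float, void>, const void>);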
@@ -15,6 +15,7 @@
 #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
 #include "ck_tile/host/reference/reference_batched_gemm.hpp"
 #include "ck_tile/host/reference/reference_batched_masking.hpp"
+#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
...
@@ -155,7 +155,12 @@ struct HostTensorDescriptor
         return space;
     }
 
+    std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
+
     const std::vector<std::size_t>& get_lengths() const { return mLens; }
+
+    std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
+
     const std::vector<std::size_t>& get_strides() const { return mStrides; }
 
     template <typename... Is>
@@ -325,8 +330,12 @@ struct HostTensor
     {
     }
 
+    std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
+
     decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
+
+    std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
+
     decltype(auto) get_strides() const { return mDesc.get_strides(); }
 
     std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
...
@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
 {
     // clang-format off
     if(!s.time_kernel_) {
-        (callables(s),...); hip_check_error(hipGetLastError());
+        (callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
         return 0;
     }
 
     if(s.is_gpu_timer_) {
         gpu_timer timer {};
 
         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
 
         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);
 
         return timer.duration() / s.nrepeat_;
@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
         cpu_timer timer {};
 
         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
 
         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);
 
         return timer.duration() / s.nrepeat_;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace ck_tile {
template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bsd,
bool use_1_row_sin_cos = false)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[1];
const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
: cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
: sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
if(interleaved)
{
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t half_rdim = (rotary_dim / 2);
const index_t pos = (i_d + half_rdim) % rotary_dim;
const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile
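For reference, the rotation this host code implements, written out for one interleaved pair (where theta_{s,i} is the angle whose cosine and sine are stored at row s, column i of cos_sd/sin_sd):

\begin{aligned}
\mathrm{out}_{2i}   &= \mathrm{in}_{2i}\cos\theta_{s,i} - \mathrm{in}_{2i+1}\sin\theta_{s,i} \\
\mathrm{out}_{2i+1} &= \mathrm{in}_{2i+1}\cos\theta_{s,i} + \mathrm{in}_{2i}\sin\theta_{s,i}
\end{aligned}

The code folds the second operand of each line into half_rotated_input with the appropriate sign; the non-interleaved (half-rotated) variant pairs dimension d with d ± rotary_dim/2 instead, and dimensions at or beyond rotary_dim pass through unchanged.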
@@ -7,7 +7,11 @@
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
@@ -21,11 +25,11 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
...
@@ -43,9 +43,12 @@ enum struct AlibiMode
     FROM_BOTTOM_RIGHT = 2,
 };
 
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 struct Alibi
 {
+    static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
+                  "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
+
     // RowMajor here means if pixel within the same thread are along the row, or col
     // this may impact the performance of update(), while the result are the same.
     // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
@@ -79,6 +82,19 @@ struct Alibi
         mode = mode_;
     }
 
+    CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
+
+    CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+    {
+        if constexpr(LogMaxSadOprndSize <= 16)
+        {
+            return sad_u16(
+                static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
+        }
+
+        return sad_u32(x, y, acc);
+    }
+
     CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
     {
         if constexpr(RowMajor)
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
 // can convert from the FA style left/right to our generic coordinate
 // if left_size < 0 && right_size = 0, it is normal causal mask
 // local is left_size >=0 or right_size >=0
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                  index_t window_left_size,
                                                  index_t window_right_size,
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
     AlibiMode alibi_mode =
         is_causal ? AlibiMode::VERTICAL
                   : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
-    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
+    return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
 }
 
 // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
...
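LogMaxSadOprndSize lets a caller promise that the row/column indices fed to sad (and the accumulated distance) fit in 16 bits, so the device path can use the single-instruction sad_u16 rather than the inline-asm sad_u32. An illustrative instantiation, assuming both sequence lengths are below 2^16 (constructor arguments as shown above):

// Sketch: 16 is also the default, spelled out here for clarity.
Alibi<float, /*RowMajor=*/true, /*LogMaxSadOprndSize=*/16> alibi{
    slope, y_total, x_total, AlibiMode::FROM_BOTTOM_RIGHT};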
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
template <typename DistributedTensor,
typename OtherDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
OtherDramBlockWindow other_window,
RotaryCosDramBlockWindow rotary_cos_window,
RotarySinDramBlockWindow rotary_sin_window,
index_t rotary_dim,
index_t thread_end)
{
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
if(thread_end <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
});
}
}
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
if(thread_end <= rotary_dim)
{
const bool is_left = (thread_end <= (rotary_dim / 2));
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto other_tile = load_tile(other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
tile.thread_buf_[idx] =
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
}
};
} // namespace ck_tile
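The name strings exist for codegen-style kernel naming, as the header comment says. A hedged sketch of how such a suffix might be assembled (this helper is hypothetical, not part of the commit):

#include <string>

template <ck_tile::RotaryEmbeddingEnum RotaryEnum>
std::string rope_suffix()
{
    constexpr const char* n = ck_tile::RotaryEmbeddingEnumToStr<RotaryEnum>::name;
    return n[0] == '\0' ? std::string{} : std::string{"_"} + n; // "", "_inter", "_half"
}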
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
using DataType = typename TensorView::DataType;
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
: tensor_view(tensor_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
return make_tuple(/*block_index=*/0,
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE constexpr auto
make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
return make_tuple(
/*block_index=*/0,
ck_tile::make_tile_window(
tensor_view, window_lengths, window_origin, tile_distribution));
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
{
ck_tile::move_tile_window(tile_window, step);
return /*block_index=*/0;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
}
private:
TensorView tensor_view;
};
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
using DataType = DataType_;
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
long_index_t block_stride_,
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorView& complete_view_,
const TensorView& last_view_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
complete_view(complete_view_),
last_view(last_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin,
tile_distribution);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
return new_block_index;
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
const TileWindow& tile_window) const
{
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE void
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
{
const multi_index<2> step = [&]() {
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
if constexpr(VirtualDim == 0)
{
return make_multi_index(origin_diff, 0);
}
else
{
return make_multi_index(0, origin_diff);
}
}();
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(tile_window.get_window_origin() + step);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
}
CK_TILE_HOST_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
if constexpr(VirtualDim == 0)
{
const index_t length = global_window_origin.at(number<0>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(length - page_block_size * num_complete_blocks,
global_window_origin.at(number<1>{}));
}
else
{
const index_t length = global_window_origin.at(number<1>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(global_window_origin.at(number<0>{}),
length - page_block_size * num_complete_blocks);
}
}
CK_TILE_HOST_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
if constexpr(VirtualDim == 0)
{
return make_multi_index(block_index * page_block_size +
local_window_origin.at(number<0>{}),
local_window_origin.at(number<1>{}));
}
else
{
return make_multi_index(local_window_origin.at(number<0>{}),
block_index * page_block_size +
local_window_origin.at(number<1>{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
}
DataType* physical_blocks;
long_index_t block_stride;
long_index_t fixed_offset;
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorView complete_view;
TensorView last_view;
};
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
return TrivialPageBlockNavigator<TensorView>(tensor_view);
}
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorView& complete_view,
const TensorView& last_view)
{
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
complete_view,
last_view);
}
} // namespace ck_tile
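In the paged-KV use case, physical_blocks is the pool of page blocks, physical_block_indices is the per-sequence page table, and VirtualDim is the sequence-length dimension of the 2D window. A hedged sketch of constructing a navigator for a K-cache (all names and values illustrative):

// Assumptions: k_ptr points at the page pool, block_indices is this batch's
// page table, and the two views describe one full page and the shorter tail page.
auto k_navigator = make_page_block_navigator<const KDataType, /*VirtualDim=*/0>(
    k_ptr,             // copy_const_t<const KDataType, void>*: page pool base
    block_stride,      // elements between consecutive physical pages
    fixed_offset,      // e.g. offset of this head inside a page
    block_indices,     // const int32_t* page table
    num_blocks,
    page_block_size,   // rows (seqlen) per page
    k_page_view,       // tensor view of a complete page
    k_last_page_view); // tensor view of the (possibly partial) last page

auto [i_block, k_window] = k_navigator.make_tile_window(window_lengths, {0, 0});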
(One further file's diff is collapsed in the original page and not shown here.)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static_assert(kK0 == kN1);
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()()
{
const index_t i_tile = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
}
};
} // namespace ck_tile
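To make the grid shape concrete, a worked example with illustrative numbers:

// kM0 = 64 (Q-side tile rows), kN0 = 64 (new-K tile rows)
// seqlen_q = 128    -> ceil(128 / 64) = 2 tiles on the Q side
// seqlen_knew = 200 -> ceil(200 / 64) = 4 tiles on the K/V side
// GridSize(...) = dim3(max(2, 4), nhead, batch_size) = dim3(4, nhead, batch_size)
// Blocks past their own side's tile count simply have no work for that side.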