Commit e536d321 authored by illsilin's avatar illsilin

merge from public repo

parents 829e0eb3 52410b49
......@@ -102,10 +102,9 @@ __global__ void
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t e_batch_offset =
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
......@@ -118,14 +117,14 @@ __global__ void
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
if constexpr(isMultiA || isMultiB)
{
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
// compute_ptr_offset_of_n_ does not need BatchStrideB, so handle the
// case where isMultiA is false but isMultiB is true
......@@ -136,27 +135,27 @@ __global__ void
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}([&](auto i) {
p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
});
}
else
{
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_for<0, 1, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
}
const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_e_grid + e_group_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -169,19 +168,19 @@ __global__ void
}
else
{
const long_index_t a_batch_offset =
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset + a_n_offset,
p_bs_grid + b_batch_offset,
p_as_grid + a_group_offset + a_n_offset,
p_bs_grid + b_group_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_e_grid + e_group_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -283,7 +282,8 @@ template <index_t NDimSpatial,
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
......@@ -302,6 +302,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
......@@ -318,7 +320,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
EDataType,
NumGroupsToMerge>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......@@ -517,7 +520,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideA_(i) =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
......@@ -545,7 +549,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_(i) =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
......@@ -565,8 +570,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers
......@@ -583,7 +590,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
......@@ -602,7 +610,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
// populate desc for Ds/E
......@@ -726,7 +734,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_;
const index_t gdy = arg.num_group_ / NumGroupsToMerge;
const index_t gdz = num_workgroups_per_Conv_N;
const auto K =
......@@ -850,6 +858,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
// check device
if(get_device_name() == "gfx908")
{
......@@ -898,6 +910,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
}
if constexpr(NumGroupsToMerge > 1)
{
if(C != 1)
{
return false;
}
if(G % NumGroupsToMerge != 0)
{
return false;
}
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
}
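// Illustration (not part of the source): with G = 64 groups, C = 1 per group
// (e.g. a depthwise convolution) and NumGroupsToMerge = 4, the checks above
// pass, the launch uses gdy = 64 / 4 = 16 workgroups along the group
// dimension, and every per-group batch stride initialized earlier is scaled
// by 4 so that a single workgroup covers 4 consecutive groups.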
// check vector access of A
// FIXME: layout
......@@ -907,11 +955,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
}
else
......@@ -928,8 +981,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
......@@ -953,8 +1004,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
......@@ -999,8 +1048,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
......@@ -1298,7 +1345,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<< BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< CShuffleNXdlPerWavePerShuffle << ", "
<< NumGroupsToMerge
<< ">";
// clang-format on
......
......@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
......
......@@ -12,7 +12,7 @@ namespace device {
// 1d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NWGK_GKXC_NWGC()
constexpr bool is_NWGC_GKXC_NWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
......@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNWK_GKXC_GNWC()
constexpr bool is_GNWC_GKXC_GNWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
......@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
}
// 2d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGK_GKYXC_NHWGC()
constexpr bool is_NHWGC_GKYXC_NHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
......@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNHWK_GKYXC_GNHWC()
constexpr bool is_GNHWC_GKYXC_GNHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_GKYXC_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
{
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
......@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
{
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
......@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
{
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
{
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
}
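// Illustrative compile-time checks (not part of the source) of the renamed
// channels-last predicates:
static_assert(is_NHWGC_GKYXC_NHWGK<tensor_layout::convolution::NHWGC,
                                   tensor_layout::convolution::GKYXC,
                                   tensor_layout::convolution::NHWGK>());
static_assert(!is_NSpatialGC_GKSpatial_NSpatialGK<tensor_layout::convolution::GNHWC,
                                                  tensor_layout::convolution::GKYXC,
                                                  tensor_layout::convolution::GNHWK>());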
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
static constexpr const char* name = "NDHWGC";
};
// input tensor
// packed NGCW/NGCHW/NGCDHW
struct NGCW : public BaseTensorLayout
{
static constexpr const char* name = "NGCW";
};
struct NGCHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCHW";
};
struct NGCDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGCDHW";
};
// input tensor
// strided layout
struct G_NW_C : public BaseTensorLayout
......@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
static constexpr const char* name = "NDHWGK";
};
struct NGKW : public BaseTensorLayout
{
static constexpr const char* name = "NGKW";
};
struct NGKHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKHW";
};
struct NGKDHW : public BaseTensorLayout
{
static constexpr const char* name = "NGKDHW";
};
// output tensor
// strided layout
struct G_NW_K : public BaseTensorLayout
......
......@@ -41,6 +41,55 @@ __global__ void
elementwise_op);
}
template <typename GridwiseElementwiseFunctor,
typename InAGridDescTuple,
typename InBGridDescTuple,
typename OutAGridDescTuple,
typename OutBGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMapA,
typename Block2TileMapB,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
const InBGridDescTuple in_grid_desc_tuple_b,
const OutAGridDescTuple out_grid_desc_tuple_a,
const OutBGridDescTuple out_grid_desc_tuple_b,
const InDataTypePointerTuple p_in_global_tuple_a,
const InDataTypePointerTuple p_in_global_tuple_b,
const OutDataTypePointerTuple p_out_global_tuple_a,
const OutDataTypePointerTuple p_out_global_tuple_b,
const Block2TileMapA block_2_tile_map_a,
const Block2TileMapB block_2_tile_map_b,
const ElementwiseOperation elementwise_op,
const index_t a_grid_size)
{
if(get_block_1d_id() < a_grid_size)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
out_grid_desc_tuple_a,
p_in_global_tuple_a,
p_out_global_tuple_a,
block_2_tile_map_a,
elementwise_op,
get_block_1d_id());
}
else
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
out_grid_desc_tuple_b,
p_in_global_tuple_b,
p_out_global_tuple_b,
block_2_tile_map_b,
elementwise_op,
get_block_1d_id() - a_grid_size);
}
}
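// Launch sketch (illustrative; descriptor and tile-map names are assumptions):
// the host side computes both sub-grid sizes, launches their sum as a single
// grid, and passes a_grid_size so the kernel above can route blocks
// [0, a_grid_size) to problem A and the remaining blocks to problem B.
//   const index_t a_grid_size = block_2_tile_map_a.CalculateGridSize(out_grid_desc_a);
//   const index_t b_grid_size = block_2_tile_map_b.CalculateGridSize(out_grid_desc_b);
//   kernel_elementwise_dual<GridwiseElementwiseFunctor, ...>
//       <<<dim3(a_grid_size + b_grid_size), dim3(BlockSize), 0, stream>>>(
//           /* descriptors, pointers, tile maps, elementwise_op, */ a_grid_size);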
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
......@@ -133,7 +182,8 @@ struct GridwiseElementwise
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
const ElementwiseOperation& elementwise_op,
const index_t block_id = get_block_1d_id())
{
constexpr auto src_datas = generate_tuple(
......@@ -169,7 +219,7 @@ struct GridwiseElementwise
Number<NumOutput>{});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
......
......@@ -46,6 +46,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
......@@ -156,6 +157,14 @@
#endif
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
......
......@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
standard = 0, // rtn
truncate_with_nan,
truncate,
standard_asm,
};
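// Usage sketch (illustrative): the new mode is selected at compile time via the
// raw-conversion dispatch further below, e.g.
//   uint16_t raw = float_to_bf16_raw(x, constant<bf16_rounding_mode::standard_asm>{});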
template <bf16_rounding_mode rounding =
......@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
return uint16_t(u.int32 >> 16);
}
CK_TILE_HOST
constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rtn_asm(float f)
{
union
{
float fp32;
uint32_t int32;
} u = {f};
static constexpr uint32_t FP32_NAN = 0x7fff0000;
static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
uint32_t tmp;
asm volatile("\n \
v_cmp_u_f32 %0, %2, %2 \n \
v_bfe_u32 %1, %2, 16, 1 \n \
v_add3_u32 %1, %2, %1, %3 \n \
v_cndmask_b32 %2, %1, %4, %0 \n \
v_lshrrev_b32 %2, 16, %2 \n \
"
: "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
: "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
return uint16_t(u.int32);
}
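// Scalar sketch (illustrative, not used by the library) of what the inline asm
// above computes, based on the v_cmp_u / v_bfe / v_add3 / v_cndmask sequence:
// round-to-nearest-even, with a fixed quiet-NaN bit pattern for NaN inputs.
CK_TILE_HOST uint16_t float_to_bf16_rtn_asm_scalar_sketch(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    if(f != f) // v_cmp_u_f32: unordered compare, i.e. NaN input
        return uint16_t(0x7fff0000u >> 16);
    const uint32_t lsb = (u.int32 >> 16) & 1u; // v_bfe_u32: LSB of the kept mantissa
    u.int32 += lsb + 0x7fffu;                  // v_add3_u32: round-to-nearest-even bias
    return uint16_t(u.int32 >> 16);            // v_lshrrev_b32: keep the upper 16 bits
}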
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
......@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else
......
......@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
{
// TODO: this is hacky; it relies on the 16-bit SAD intrinsic
return __builtin_amdgcn_sad_u16(x, y, acc);
}
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
{
/// TODO: replace inline asm when intrinsic is available
uint32_t res;
asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
return res;
}
CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
{
return (x > y ? (x - y) : (y - x)) + acc;
}
......
......@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
......@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm volatile(
""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
......@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
}
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
......@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin};
}
// duplicate tile window and replace its origin
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin)
{
return tile_window_with_static_lengths<TensorView, WindowLengths>{
tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution)
{
return make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
return make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
}
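// Usage sketch (illustrative; the window and distribution names are assumptions):
//   // re-anchor an existing window at a new origin
//   auto shifted_window = make_tile_window(k_dram_window, make_multi_index(0, i_n0));
//   // keep the origin but attach a static tile distribution
//   auto k_distributed_window = make_tile_window(k_dram_window, k_dram_tile_distribution);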
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
......
......@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
template <typename From, typename To>
struct copy_const
{
static_assert(!std::is_const_v<From>);
using type = To;
};
template <typename From, typename To>
struct copy_const<const From, To>
{
using type = std::add_const_t<typename copy_const<From, To>::type>;
};
template <typename From, typename To>
using copy_const_t = typename copy_const<From, To>::type;
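// Illustrative checks (not part of the source): copy_const propagates the
// const qualifier of the first type onto the second.
static_assert(std::is_same_v<copy_const_t<float, void>, void>);
static_assert(std::is_same_v<copy_const_t<const float, void>, const void>);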
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
......
......@@ -15,6 +15,7 @@
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
......
......@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return space;
}
std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
const std::vector<std::size_t>& get_lengths() const { return mLens; }
std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is>
......@@ -325,8 +330,12 @@ struct HostTensor
{
}
std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
decltype(auto) get_strides() const { return mDesc.get_strides(); }
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
......
......@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{
// clang-format off
if(!s.time_kernel_) {
(callables(s),...); hip_check_error(hipGetLastError());
(callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
return 0;
}
if(s.is_gpu_timer_) {
gpu_timer timer {};
// warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
......@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
cpu_timer timer {};
// warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace ck_tile {
template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bsd,
bool use_1_row_sin_cos = false)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[1];
const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
: cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
: sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
if(interleaved)
{
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t half_rdim = (rotary_dim / 2);
const index_t pos = (i_d + half_rdim) % rotary_dim;
const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile
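For reference, the two rotary modes implemented above compute the following (a restatement of the code, with $r$ the rotary dimension and the cos/sin angle index taken modulo $r/2$):

$$\text{interleaved:}\quad y_{2i} = x_{2i}\cos\theta_i - x_{2i+1}\sin\theta_i,\qquad y_{2i+1} = x_{2i+1}\cos\theta_i + x_{2i}\sin\theta_i$$

$$\text{half-rotated:}\quad y_d = x_d\cos\theta_d - x_{d+r/2}\sin\theta_d \;\,(d < r/2),\qquad y_d = x_d\cos\theta_d + x_{d-r/2}\sin\theta_d \;\,(d \ge r/2)$$

Dimensions $d \ge r$ are copied through unchanged.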
......@@ -7,7 +7,11 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
......@@ -21,11 +25,11 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
......
......@@ -43,9 +43,12 @@ enum struct AlibiMode
FROM_BOTTOM_RIGHT = 2,
};
template <typename DataType, bool RowMajor = true>
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
struct Alibi
{
static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
"for LogMaxSadOprndSize <= 16 we use the uint16_t SAD path, otherwise the uint32_t SAD path");
// RowMajor here means whether the pixels within the same thread run along the row or the column.
// This may impact the performance of update(), while the result stays the same.
// e.g. fwd prefers RowMajor=true, while some bwd cases prefer RowMajor=false.
......@@ -79,6 +82,19 @@ struct Alibi
mode = mode_;
}
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
if constexpr(LogMaxSadOprndSize <= 16)
{
return sad_u16(
static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
}
return sad_u32(x, y, acc);
}
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
{
if constexpr(RowMajor)
......@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
// converts from the FA-style left/right window sizes to our generic coordinates
// if left_size < 0 && right_size == 0, it is a normal causal mask
// it is a local mask when left_size >= 0 or right_size >= 0
template <typename DataType, bool RowMajor = true>
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
index_t window_left_size,
index_t window_right_size,
......@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
AlibiMode alibi_mode =
is_causal ? AlibiMode::VERTICAL
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
template <typename DistributedTensor,
typename OtherDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
OtherDramBlockWindow other_window,
RotaryCosDramBlockWindow rotary_cos_window,
RotarySinDramBlockWindow rotary_sin_window,
index_t rotary_dim,
index_t thread_end)
{
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
if(thread_end <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
});
}
}
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
if(thread_end <= rotary_dim)
{
const bool is_left = (thread_end <= (rotary_dim / 2));
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto other_tile = load_tile(other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
tile.thread_buf_[idx] =
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
using DataType = typename TensorView::DataType;
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
: tensor_view(tensor_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
return make_tuple(/*block_index=*/0,
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE constexpr auto
make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
return make_tuple(
/*block_index=*/0,
ck_tile::make_tile_window(
tensor_view, window_lengths, window_origin, tile_distribution));
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
{
ck_tile::move_tile_window(tile_window, step);
return /*block_index=*/0;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
}
private:
TensorView tensor_view;
};
// default page-block navigator; assumes that the tensor view size equals the page-block size,
// or is smaller when the tile window is on the last page block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
using DataType = DataType_;
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
long_index_t block_stride_,
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorView& complete_view_,
const TensorView& last_view_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
complete_view(complete_view_),
last_view(last_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin,
tile_distribution);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
return new_block_index;
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
const TileWindow& tile_window) const
{
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE void
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
{
const multi_index<2> step = [&]() {
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
if constexpr(VirtualDim == 0)
{
return make_multi_index(origin_diff, 0);
}
else
{
return make_multi_index(0, origin_diff);
}
}();
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(tile_window.get_window_origin() + step);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
}
CK_TILE_HOST_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
if constexpr(VirtualDim == 0)
{
const index_t length = global_window_origin.at(number<0>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(length - page_block_size * num_complete_blocks,
global_window_origin.at(number<1>{}));
}
else
{
const index_t length = global_window_origin.at(number<1>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(global_window_origin.at(number<0>{}),
length - page_block_size * num_complete_blocks);
}
}
CK_TILE_HOST_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
if constexpr(VirtualDim == 0)
{
return make_multi_index(block_index * page_block_size +
local_window_origin.at(number<0>{}),
local_window_origin.at(number<1>{}));
}
else
{
return make_multi_index(local_window_origin.at(number<0>{}),
block_index * page_block_size +
local_window_origin.at(number<1>{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
}
DataType* physical_blocks;
long_index_t block_stride;
long_index_t fixed_offset;
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorView complete_view;
TensorView last_view;
};
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
return TrivialPageBlockNavigator<TensorView>(tensor_view);
}
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorView& complete_view,
const TensorView& last_view)
{
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
complete_view,
last_view);
}
} // namespace ck_tile
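A usage sketch of the paged navigator (illustrative; the pointer, stride, and view names are assumptions, not taken from the source): navigate a K-cache whose sequence dimension (VirtualDim = 0) is split into fixed-size page blocks addressed through a block-index table.

auto k_page_navigator = make_page_block_navigator<const KDataType, 0>(
    p_k_cache,        // base pointer of all physical page blocks
    k_block_stride,   // elements between consecutive physical page blocks
    k_fixed_offset,   // per-batch/per-head offset inside each block
    p_block_table,    // int32_t physical block indices of this sequence
    num_blocks,
    page_block_size,
    k_complete_view,  // tensor view describing a full page block
    k_last_view);     // tensor view describing the possibly shorter last block
auto block_and_window =
    k_page_navigator.make_tile_window(k_window_lengths, make_multi_index(i_seq_start, 0));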
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static_assert(kK0 == kN1);
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()()
{
const index_t i_tile = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
}
};
} // namespace ck_tile
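A worked illustration (the numbers are assumptions, not from the source): with kM0 = kN0 = 64, seqlen_q = 1, and seqlen_knew = 256, GridSize() returns dim3(max(ceil(1/64), ceil(256/64)), nhead, batch_size) = dim3(4, nhead, batch_size); each block then recovers its (i_tile, i_nhead, i_batch) triple from blockIdx via operator()().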