Unverified Commit f84e2020 authored by Rostyslav Geyyer, committed by GitHub

Merge branch 'develop' into lwpck-1815

parents 408534d4 25935b57
@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
            ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
 }

+inline bool is_bf16_atomic_supported()
+{
+    return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+           ck::get_device_name() == "gfx942";
+}
+
 inline bool is_gfx101_supported()
 {
     return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
...
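For context: `is_bf16_atomic_supported()` reports whether the device has packed bf16 global atomics (gfx940/941/942), which the split-K paths below rely on. A minimal host-side sketch of how a caller might use it, assuming the helper lives next to `get_device_name()` in namespace `ck`; `pick_kbatch` is a hypothetical function, not part of this change:

#include "ck/host_utility/device_prop.hpp" // assumed header for these helpers

// Hypothetical: fall back to KBatch == 1 when a bf16 output cannot be
// accumulated atomically on this device.
int pick_kbatch(bool output_is_bf16, int desired_kbatch)
{
    if(output_is_bf16 && !ck::is_bf16_atomic_supported())
        return 1; // split-K needs bf16 atomic adds; not available here
    return desired_kbatch;
}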
@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

+// GEMM:
+//   input : A[M, K], B[K, N],
+//   input : D0[M, N], D1[M, N], ...
+//   output : E[M, N]
+//   C = a_op(A) * b_op(B)
+//   E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
+template <typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceGemmMultipleDSplitK : public BaseOperator
+{
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::index_t StrideE,
+                        ck::index_t KBatch,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
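The new `DeviceGemmMultipleDSplitK` base mirrors `DeviceGemmMultipleD` but threads a `KBatch` argument through. A hedged usage sketch; `SomeDeviceOp`, the pointers, strides, and element ops are placeholders, not from this change:

// SomeDeviceOp: any concrete op implementing DeviceGemmMultipleDSplitK.
auto op  = SomeDeviceOp{};
auto arg = op.MakeArgumentPointer(p_a, p_b, {p_d0}, p_e,
                                  M, N, K,
                                  StrideA, StrideB, {StrideD0}, StrideE,
                                  /*KBatch=*/4, // split K into 4 partial GEMMs
                                  a_op, b_op, cde_op);
if(op.IsSupportedArgument(arg.get()))
    op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{});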
@@ -69,17 +69,17 @@ template <typename ALayout,
           typename ComputeTypeB = ComputeTypeA,
           typename LDSTypeA     = ComputeTypeA,
           typename LDSTypeB     = ComputeTypeB>
-struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
-                                                                     BLayout,
-                                                                     DsLayout,
-                                                                     CLayout,
-                                                                     ADataType,
-                                                                     BDataType,
-                                                                     DsDataType,
-                                                                     CDataType,
-                                                                     AElementwiseOperation,
-                                                                     BElementwiseOperation,
-                                                                     CElementwiseOperation>
+struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
+                                                                           BLayout,
+                                                                           DsLayout,
+                                                                           CLayout,
+                                                                           ADataType,
+                                                                           BDataType,
+                                                                           DsDataType,
+                                                                           CDataType,
+                                                                           AElementwiseOperation,
+                                                                           BElementwiseOperation,
+                                                                           CElementwiseOperation>
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                // rotating mem
                rotating_mem.Next();
                // clear c mem
-               if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-               {
-                   if(arg_.KBatch > 1)
-                       hipGetErrorString(
-                           hipMemsetAsync(arg_.p_c_grid,
-                                          0,
-                                          arg_.M * arg_.N * sizeof(CDataType),
-                                          stream_config.stream_id_));
-               }
+               if(arg_.KBatch > 1)
+                   hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
+                                                    0,
+                                                    arg_.M * arg_.N * sizeof(CDataType),
+                                                    stream_config.stream_id_));
            };

            ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
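The bf16 exclusion around the memset is gone because the result buffer must be zeroed for every split-K run: each K-batch kernel commits its partial sum with AtomicAdd, so correctness needs a zero starting value regardless of data type (the bf16 restriction moves into IsSupportedArgument below). A plain C++ model of the invariant, with the `+= 1.0f` as a stand-in for the per-batch GEMM result:

#include <cstddef>
#include <vector>

// Host-side model only: KBatch kernels each add their slice of the K
// reduction into E, so E must start at zero (the hipMemsetAsync above).
void splitk_model(std::vector<float>& E, int KBatch)
{
    for(float& e : E)
        e = 0.0f;                      // the memset step
    for(int kb = 0; kb < KBatch; ++kb) // one launch per K batch
        for(std::size_t i = 0; i < E.size(); ++i)
            E[i] += 1.0f;              // stand-in for Partial(kb, i)
}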
@@ -234,38 +230,161 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
+               if(arg.KBatch > 1)
+               {
+                   const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                       GridwiseGemm,
+                       true,
+                       InMemoryDataOperationEnum::AtomicAdd,
+                       minimum_occupancy>;
+                   Run(kernel);
+               }
+               else
                {
                    const auto kernel =
-                       kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                   true,
-                                                   InMemoryDataOperationEnum::Set,
-                                                   minimum_occupancy>;
+                       kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                           true,
+                                                           InMemoryDataOperationEnum::Set,
+                                                           minimum_occupancy>;
                    Run(kernel);
                }
            }
            // Tail number could be One to Seven
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
            {
+               if(arg.KBatch > 1)
+               {
+                   if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::One>;
+                       Run(kernel);
+                   }
+                   else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                           TailNumber::Full)
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::Full>;
+                       Run(kernel);
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Two>;
+                           Run(kernel);
+                       }
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                          TailNumber::Three)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Three>;
+                           Run(kernel);
+                       }
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                          TailNumber::Four)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Four>;
+                           Run(kernel);
+                       }
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                          TailNumber::Five)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Five>;
+                           Run(kernel);
+                       }
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Six>;
+                           Run(kernel);
+                       }
+                   }
+
+                   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                   {
+                       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
+                          TailNumber::Seven)
+                       {
+                           const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                               GridwiseGemm,
+                               true,
+                               InMemoryDataOperationEnum::AtomicAdd,
+                               minimum_occupancy,
+                               TailNumber::Seven>;
+                           Run(kernel);
+                       }
+                   }
+               }
+               else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                    {
                        const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::One>;
+                           kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                               true,
+                                                               InMemoryDataOperationEnum::Set,
+                                                               minimum_occupancy,
+                                                               TailNumber::One>;
                        Run(kernel);
                    }
                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                            TailNumber::Full)
                    {
                        const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Full>;
+                           kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                               true,
+                                                               InMemoryDataOperationEnum::Set,
+                                                               minimum_occupancy,
+                                                               TailNumber::Full>;
                        Run(kernel);
                    }
@@ -273,12 +392,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Two>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Two>;
                        Run(kernel);
                    }
                }
@@ -288,12 +407,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                       TailNumber::Three)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Three>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Three>;
                        Run(kernel);
                    }
                }
@@ -303,12 +422,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                       TailNumber::Four)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Four>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Four>;
                        Run(kernel);
                    }
                }
@@ -318,12 +437,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                       TailNumber::Five)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Five>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Five>;
                        Run(kernel);
                    }
                }
@@ -332,12 +451,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Six>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Six>;
                        Run(kernel);
                    }
                }
@@ -347,12 +466,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                       TailNumber::Seven)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Seven>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Seven>;
                        Run(kernel);
                    }
                }
@@ -361,51 +480,98 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            // Tail number could be Odd or Even
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
            {
+               if(arg.KBatch > 1)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
-                                                            true,
-                                                            InMemoryDataOperationEnum::Set,
-                                                            minimum_occupancy,
-                                                            TailNumber::Odd>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
-                       const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
-                                                            true,
-                                                            InMemoryDataOperationEnum::Set,
-                                                            minimum_occupancy,
-                                                            TailNumber::Even>;
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::Even>;
+                       Run(kernel);
+                   }
+               }
+               else
+               {
+                   if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Odd>;
+                       Run(kernel);
+                   }
+                   else
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::Set,
+                           minimum_occupancy,
+                           TailNumber::Even>;
                        Run(kernel);
                    }
                }
            }
            else
            {
+               if(arg.KBatch > 1)
+               {
+                   if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::Odd>;
+                       Run(kernel);
+                   }
+                   else
+                   {
+                       const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                           GridwiseGemm,
+                           true,
+                           InMemoryDataOperationEnum::AtomicAdd,
+                           minimum_occupancy,
+                           TailNumber::Even>;
+                       Run(kernel);
+                   }
+               }
+               else
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Odd>;
+                           kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                               true,
+                                                               InMemoryDataOperationEnum::Set,
+                                                               minimum_occupancy,
+                                                               TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
-                           kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                       true,
-                                                       InMemoryDataOperationEnum::Set,
-                                                       minimum_occupancy,
-                                                       TailNumber::Even>;
+                           kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                               true,
+                                                               InMemoryDataOperationEnum::Set,
+                                                               minimum_occupancy,
+                                                               TailNumber::Even>;
                        Run(kernel);
                    }
                }
@@ -416,12 +582,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            // Tail number always 1
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
            {
+               if(arg.KBatch > 1)
+               {
+                   const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
+                       GridwiseGemm,
+                       false,
+                       InMemoryDataOperationEnum::AtomicAdd,
+                       minimum_occupancy>;
+                   Run(kernel);
+               }
+               else
                {
                    const auto kernel =
-                       kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
-                                                   false,
-                                                   InMemoryDataOperationEnum::Set,
-                                                   minimum_occupancy>;
+                       kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
+                                                           false,
+                                                           InMemoryDataOperationEnum::Set,
+                                                           minimum_occupancy>;
                    Run(kernel);
                }
            }
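Across all pipeline versions the dispatch follows one rule, restated here as a hedged one-liner (a plain restatement for readability, not a new API):

// KBatch > 1: partial results from different K batches must accumulate,
// so the kernel writes with AtomicAdd; a single batch can simply Set.
ck::InMemoryDataOperationEnum pick_mem_op(ck::index_t k_batch)
{
    return k_batch > 1 ? ck::InMemoryDataOperationEnum::AtomicAdd
                       : ck::InMemoryDataOperationEnum::Set;
}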
@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
            return false;
        }

+       if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
+       {
+           return false;
+       }
+
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                             index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
+                            index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        StrideB,
                        StrideDs,
                        StrideC,
-                       1,
+                       KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                        index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        index_t StrideC,
+                       index_t KBatch,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) override
@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
                                         StrideB,
                                         StrideDs,
                                         StrideC,
-                                        1,
+                                        KBatch,
                                         a_element_op,
                                         b_element_op,
                                         c_element_op);
...
@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                // rotating mem
                rotating_mem.Next();
                // clear c mem
-               if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-               {
-                   if(arg_.KBatch > 1)
-                       hipGetErrorString(
-                           hipMemsetAsync(arg_.p_c_grid,
-                                          0,
-                                          arg_.M * arg_.N * sizeof(CDataType),
-                                          stream_config.stream_id_));
-               }
+               if(arg_.KBatch > 1)
+                   hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
+                                                    0,
+                                                    arg_.M * arg_.N * sizeof(CDataType),
+                                                    stream_config.stream_id_));
            };

            ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            }
            else
            {
-               if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-               {
-                   if(arg.KBatch > 1)
-                       hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
-                                                        0,
-                                                        arg.M * arg.N * sizeof(CDataType),
-                                                        stream_config.stream_id_));
-               }
+               if(arg.KBatch > 1)
+                   hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
+                                                    0,
+                                                    arg.M * arg.N * sizeof(CDataType),
+                                                    stream_config.stream_id_));

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value) const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = true,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy>;
Run(kernel);
}
} }
else else
{ {
@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            {
                if(arg.KBatch > 1)
                {
-                   if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                   {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy,
                                                        TailNumber::One>;
                        Run(kernel);
                    }
                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                            TailNumber::Full)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy,
                                                        TailNumber::Full>;
                        Run(kernel);
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Two>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                           TailNumber::Three)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Three>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                           TailNumber::Four)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Four>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                           TailNumber::Five)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Five>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Six>;
                            Run(kernel);
                        }
                    }

                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
                    {
                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                           TailNumber::Seven)
                        {
                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                GridwiseGemm,
                                true,
                                InMemoryDataOperationEnum::AtomicAdd,
                                minimum_occupancy,
                                TailNumber::Seven>;
                            Run(kernel);
                        }
                    }
-                   }
                }
            }
@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            {
                if(arg.KBatch > 1)
                {
-                   if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                   {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                            GridwiseGemm,
                            true,
                            InMemoryDataOperationEnum::AtomicAdd,
                            minimum_occupancy,
                            TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
                            GridwiseGemm,
                            true,
                            InMemoryDataOperationEnum::AtomicAdd,
                            minimum_occupancy,
                            TailNumber::Even>;
                        Run(kernel);
                    }
-                   }
                }
                else
@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            {
                if(arg.KBatch > 1)
                {
-                   if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                   {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy,
                                                        TailNumber::Odd>;
                        Run(kernel);
                    }
                    else
                    {
                        const auto kernel =
                            kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd,
                                                        minimum_occupancy,
                                                        TailNumber::Even>;
                        Run(kernel);
                    }
-                   }
                }
                else
@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            // Tail number always 1
            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
            {
                if(arg.KBatch > 1)
                {
-                   if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-                   {
                    const auto kernel =
                        kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                    false,
                                                    InMemoryDataOperationEnum::AtomicAdd,
                                                    minimum_occupancy>;
                    Run(kernel);
-                   }
                }
                else
                {
@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
            return false;
        }

+       if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
+       {
+           return false;
+       }
+
        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
...
@@ -86,7 +86,6 @@ __global__ void
        const AElementwiseOperation a_element_op,
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation cde_element_op,
-       const index_t groups_count,
        const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
        const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
        const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -101,10 +100,8 @@ __global__ void
      defined(__gfx94__))
    // offset base pointer for each work-group
-   const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
-   const index_t& num_blocks_per_n    = groups_count;
-   const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
-   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
+   const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
+   const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

    const long_index_t e_batch_offset =
        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
@@ -200,7 +197,6 @@ __global__ void
    ignore = p_bs_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
-   ignore = groups_count;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -321,8 +317,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
                                                            ConvForwardSpecialization,
                                                            true /*SplitN*/,
-                                                           ALayout,
-                                                           ELayout>;
+                                                           ADataType,
+                                                           EDataType>;

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
        const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
-       const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
-       const index_t gdz = 1;
+       const index_t gdy = arg.num_group_;
+       const index_t gdz = num_workgroups_per_Conv_N;

        const auto K =
            arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                arg.a_element_op_,
                arg.b_element_op_,
                arg.cde_element_op_,
-               arg.a_g_n_c_wis_lengths_[0], // Group count
                as_grid_desc_ak0_m_ak1,
                bs_grid_desc_bk0_n_bk1,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                arg.a_element_op_,
                arg.b_element_op_,
                arg.cde_element_op_,
-               arg.a_g_n_c_wis_lengths_[0], // Group count
                arg.a_grid_desc_ak0_m_ak1_,
                arg.b_grid_desc_bk0_n_bk1_,
                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
...
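The grid is reshaped rather than resized: groups and conv-N splits used to share gridDim.y, forcing each work-group to recover its indices with divisions by a `groups_count` kernel argument; now `blockIdx.y` is the group and `blockIdx.z` is the N split directly, so both the divisions and the extra argument disappear. A hedged host-side sketch of the new layout, with names as in the diff:

#include <hip/hip_runtime.h>

// Sketch: grid construction after this change.
// x: GEMM tiles, y: convolution groups, z: splits of the conv N dimension.
dim3 make_conv_grid(int gdx, int num_group, int num_workgroups_per_conv_n)
{
    return dim3(gdx, num_group, num_workgroups_per_conv_n);
}
// device side now reads: g_idx = blockIdx.y; n_idx = blockIdx.z;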
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -19,7 +19,7 @@ namespace device {
 template <index_t Rank, int NumReduceDim>
 std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
 {
-    static_assert(Rank <= 6, "bigger Rank size not supported!");
+    static_assert(Rank <= 12, "bigger Rank size not supported!");

     long_index_t invariant_total_length = 1;
     long_index_t reduce_total_length    = 1;
@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
 template <index_t Rank, int NumReduceDim>
 std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
 {
-    static_assert(Rank <= 6, "bigger Rank size not supported!");
+    static_assert(Rank <= 12, "bigger Rank size not supported!");

     long_index_t invariant_total_length = 1;
     long_index_t reduce_total_length    = 1;
...
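Context for the relaxed limit: `get_2d_lengths` collapses a Rank-d shape into (invariant, reduce) totals, with the last `NumReduceDim` dimensions treated as the reduced ones; raising the cap to 12 only widens the allowed Rank. A standalone model in plain C++, assuming that dimension convention:

#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// e.g. lengths_2d({2, 3, 4, 5}, 2) == {6, 20}
std::pair<long, long> lengths_2d(const std::vector<int>& in, int num_reduce_dim)
{
    const auto mid = in.end() - num_reduce_dim;
    return {std::accumulate(in.begin(), mid, 1L, std::multiplies<long>{}),
            std::accumulate(mid, in.end(), 1L, std::multiplies<long>{})};
}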
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
                                                     PropagateNan,
                                                     OutputIndex>
 {
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");

     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                   "Invalid thread cluster size assignments!");
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
                                                     OutputIndex>
 {
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");

     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
...
@@ -45,7 +45,7 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
                                                                 OutElementwiseOperation>
 {
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");

     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
...
@@ -3,7 +3,6 @@

 #pragma once

-#include "ck/utility/data_type.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 namespace ck {
@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
     UnaryOp2 unary_op2_{};
 };

+using ScaleScalePass = UnaryCombinedOp<Scale, Scale, PassThrough>;
+using ScaleScaleRelu = UnaryCombinedOp<Scale, Scale, Relu>;
+
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
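The two new aliases name common fusions of the existing `UnaryCombinedOp`. A scalar model of what `ScaleScaleRelu` computes, assuming the combined op applies its sub-ops in declaration order (a model only, not the CK functor):

// scale twice, then clamp negatives: relu(x * s0 * s1)
float scale_scale_relu(float x, float s0, float s1)
{
    float y = x * s0;           // first Scale
    y *= s1;                    // second Scale
    return y > 0.0f ? y : 0.0f; // Relu; PassThrough would return y unchanged
}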
@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
            }
        }();

+       // pad M and N
+       return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
+                                          make_tuple(make_right_pad_transform(M, MPad - M),
+                                                     make_right_pad_transform(N, NPad - N)),
+                                          make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                          make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 0
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
+#endif
    }

    struct Problem
@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
        {
            if(!(karg.M % MPerBlock == 0))
            {
@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
        {
            if(!(karg.N % NPerBlock == 0))
            {
@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }

        if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
-                      is_same<remove_cvref_t<CDataType>, float>::value))
+                      is_same<remove_cvref_t<CDataType>, float>::value ||
+                      is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
+                      is_same<remove_cvref_t<CDataType>, int32_t>::value))
        {
            if(!karg.IsReduceAdd())
            {
...
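With the specialization-dependent branch disabled via `#if 0`, the C/E descriptor is now always right-padded to whole blocks. The pad amounts follow the usual round-up, sketched here (integer ceil-division; the N dimension is handled the same way with NPerBlock):

// e.g. round_up(1000, 128) == 1024, so the right-pad adds MPad - M == 24 rows
int round_up(int x, int tile) { return (x + tile - 1) / tile * tile; }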
@@ -36,10 +36,9 @@ __global__ void
        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-       kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
+       kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
@@ -56,7 +55,7 @@ __global__ void
                                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename GridwiseGemm,
@@ -69,10 +68,9 @@ __global__ void
        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-       kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
+       kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Pass two lds pointer is the key to tell compiler that ds_read/write
    // operate on different lds chunk at same time without order dependecy
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -93,7 +91,7 @@ __global__ void
                                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename ALayout,
template <typename ALayout, template <typename ALayout,
...@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
} }
}(); }();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization; using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding || if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// not pad M or N // not pad M or N
return c_grid_desc_mraw_nraw; return c_grid_desc_mraw_nraw;
} }
#endif
} }
__host__ __device__ static auto MakeDsGridDescriptor_M_N( __host__ __device__ static auto MakeDsGridDescriptor_M_N(
...@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{ {
...@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{ {
......
@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                      dst_wave_addr_offset);
    }
}

+template <typename T, index_t N>
+__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
+                                           T* addr)
+{
+    static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
+                      (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
+                  "wrong! not implemented");
+
+    if constexpr(is_same<T, half_t>::value)
+    {
+        vector_type<half_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
+                                                      tmp.template AsType<half2_t>()[i]);
+        });
+    }
+#if defined(__gfx942__)
+    else if constexpr(is_same<T, bhalf_t>::value)
+    {
+        vector_type<bhalf_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
+                                                       tmp.template AsType<bhalf2_t>()[i]);
+        });
+    }
+#endif
+}
+
 template <typename T, index_t N>
 __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                            int32x4_t dst_wave_buffer_resource,
@@ -907,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
    using scalar_t                = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

+   if constexpr(is_same<T, bhalf_t>::value)
+   {
+       if(dst_thread_element_valid)
+       {
+           amd_global_atomic_add_impl<scalar_t, vector_size>(
+               src_thread_data, p_dst_wave + dst_thread_element_offset);
+       }
+   }
+   else
+   {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
-   uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
-   amd_buffer_atomic_add_impl<scalar_t, vector_size>(
-       src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+       uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
+       amd_buffer_atomic_add_impl<scalar_t, vector_size>(
+           src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
-   if(dst_thread_element_valid)
-   {
-       amd_buffer_atomic_add_impl<scalar_t, vector_size>(
-           src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
-   }
+       if(dst_thread_element_valid)
+       {
+           amd_buffer_atomic_add_impl<scalar_t, vector_size>(
+               src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+       }
#endif
+   }
}

// buffer_atomic_max requires:
...
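`__builtin_amdgcn_global_atomic_fadd_v2bf16` commits two bf16 lanes per atomic, which is why `amd_global_atomic_add_impl` walks the vector in steps of two. A host-side model of one packed add (truncation rounding for brevity; the hardware op is a single atomic read-modify-write):

#include <cstdint>
#include <cstring>

static float bf16_to_f32(std::uint16_t h)
{
    std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

static std::uint16_t f32_to_bf16(float f)
{
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<std::uint16_t>(u >> 16); // truncate, not round-to-nearest
}

// model of one v2bf16 atomic fadd: two lanes committed in one operation
void atomic_fadd_v2bf16_model(std::uint16_t dst[2], const std::uint16_t src[2])
{
    for(int lane = 0; lane < 2; ++lane)
        dst[lane] = f32_to_bf16(bf16_to_f32(dst[lane]) + bf16_to_f32(src[lane]));
}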
@@ -358,13 +358,15 @@ struct DynamicBuffer
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-               (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+               (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+               (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
            bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-               (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+               (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+               (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
            bool constexpr use_amd_buffer_addressing = false;
#endif
...
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
 };

 // 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
 struct xor_t : public base_transform<2, 2>
 {
     static constexpr auto type_enum = coord_transform_enum::xor_t;
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
     using UpLengths = LowLengths;

     UpLengths up_lengths_;
-    RightShift right_shift_;

-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}

-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
-                                        const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}

     CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
     {
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
         idx_low(number<0>{}) = idx_up[number<0>{}];

-        const auto idx_low_1_tmp =
-            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
-
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) =
+            idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
     }

     template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
     CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
     {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
     }

     // MUST be static function
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
         array<index_t, 2> up_vector_lengths = low_vector_lengths;
         array<index_t, 2> up_vector_strides = low_vector_strides;

-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
-
         return make_tuple(up_vector_lengths, up_vector_strides);
     }
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
         print(up_lengths_);
         printf(", ");

-        //
-        printf("right_shift_: ");
-        print(right_shift_);
-
         printf("}");
     }
 };
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
     return modulo<Modulus, UpLength>{modulus, up_length};
 }

-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
 {
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
 }

 template <typename LowLength, typename OffsetLength>
...
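The slimmed-down `xor_t` drops the configurable `right_shift_` and hard-codes the classic XOR swizzle: a row-dependent permutation of the second coordinate that is its own inverse, commonly used to dodge LDS bank conflicts. A plain C++ check of the round-trip property:

#include <cassert>

int xor_swizzle(int row, int col, int len1) { return col ^ (row % len1); }

int main()
{
    for(int row = 0; row < 16; ++row)
        for(int col = 0; col < 8; ++col) // len1 == 8
            assert(xor_swizzle(row, xor_swizzle(row, col, 8), 8) == col);
}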
...@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16))); ...@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64))); using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
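These typedefs use Clang's ext_vector_type attribute, same as the int32 family above. A minimal host-side sketch of how such a vector behaves (Clang-only; the alias name here is illustrative, not ck_tile's):

// Extended vectors support array-style element access and element-wise
// arithmetic on the whole value.
#include <cstdint>
#include <cstdio>

using u32x4 = uint32_t __attribute__((ext_vector_type(4)));

int main()
{
    u32x4 v = {1u, 2u, 3u, 4u};
    v[2] += 10u;     // per-element access
    u32x4 w = v + v; // element-wise add
    std::printf("%u %u %u %u\n", w[0], w[1], w[2], w[3]);
    return 0;
}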
// i16 // i16
// using int16_t = ... // using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2))); using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
......
...@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( ...@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple( return make_tuple(
make_static_tile_distribution( make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths, tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// h_lengths type // change the
// h_lengths type
typename Encoding::Ps2RHssMajor, typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor, typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor, typename Encoding::Ys2RHsMajor,
......
...@@ -53,6 +53,39 @@ class philox ...@@ -53,6 +53,39 @@ class philox
out_tmp[3] = tmp_ph.w; out_tmp[3] = tmp_ph.w;
} }
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
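Both new helpers draw one 16-byte Philox block and copy out only a slice: get_random_8x8 takes the 32-bit words at start_idx and start_idx + 2 (8 bytes), get_random_4x8 just the word at start_idx (4 bytes). A standalone sketch of the slicing with a stand-in for get_philox_4x32 (values are hypothetical; in the kernel start_idx is 0 or 1, so start_idx + 2 stays in range):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    // Stand-in for one get_philox_4x32() result: four 32-bit words.
    uint32_t tmp[4] = {0x11111111u, 0x22222222u, 0x33333333u, 0x44444444u};
    const int start_idx = 1;

    uint8_t out8[8]; // get_random_8x8: words start_idx and start_idx + 2
    std::memcpy(out8, &tmp[start_idx], sizeof(uint32_t));
    std::memcpy(out8 + 4, &tmp[start_idx + 2], sizeof(uint32_t));

    uint8_t out4[4]; // get_random_4x8: word start_idx only
    std::memcpy(out4, &tmp[start_idx], sizeof(uint32_t));

    std::printf("out8[4]=%02x out4[0]=%02x\n", out8[4], out4[0]);
    return 0;
}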
private: private:
struct ull2 struct ull2
{ {
......
...@@ -8,21 +8,16 @@ ...@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
......
...@@ -286,11 +286,226 @@ struct BlockDropout ...@@ -286,11 +286,226 @@ struct BlockDropout
}); });
} }
ck_tile::philox ph;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
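The false specialization keeps the same interface but compiles to a no-op, so kernels can invoke the dropout hooks unconditionally and the dead path disappears at compile time. A standalone sketch of that pattern (names hypothetical, not the ck_tile interface):

#include <cstdio>

template <bool IsDropout>
struct Dropout;

template <>
struct Dropout<false> // no-op: same interface, does nothing
{
    static constexpr float apply(float v) { return v; }
};

template <>
struct Dropout<true> // stand-in for the real masking/rescaling
{
    static constexpr float apply(float v) { return v * 0.5f; }
};

int main()
{
    std::printf("%f %f\n", Dropout<false>::apply(8.f), Dropout<true>::apply(8.f));
    return 0;
}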
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = true;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
static constexpr bool IsStoreRandval = IsStoreRandval_;
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_)
: ph(seed,
offset + (i_batch * nheads + i_head) * get_warp_size() +
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
{
}
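In the 16x16 warp-gemm case the constructor folds lane pairs onto one Philox sub-offset: (lane & 47) clears bit 4, mapping lanes 16-31 onto 0-15 and 48-63 onto 32-47, then ((warp & 1) << 4) re-inserts bit 4 from the warp id. A host-side sketch of just that bit math (the intent, presumably, is that threads covering the same randval rows share a generator stream):

#include <cstdio>

int main()
{
    for(int warp = 0; warp < 2; ++warp)
        for(int lane = 0; lane < 64; lane += 8)
            std::printf("warp %d lane %2d -> suboffset %2d\n",
                        warp, lane, (lane & 47) + ((warp & 1) << 4));
    return 0;
}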
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
number<kN1>{},
number<1>{});
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
randval_lds_block_desc_0,
ck_tile::make_tuple(
make_pass_through_transform(number<kMPerStep>{}),
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
return randval_lds_block_desc;
}
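The first descriptor strides the n0 dimension by (kMPerStep + 1) * kN1 rather than kMPerStep * kN1, i.e. one padding row per n0 slice, presumably to stagger LDS banks between slices. A sketch of the resulting element offsets (illustrative sizes; randval elements are uint8_t, so elements equal bytes):

#include <cstdio>

int main()
{
    constexpr int kMPerStep = 32, kN1 = 8, kN0 = 4;
    auto offset = [](int n0, int m, int n1) {
        return n0 * (kMPerStep + 1) * kN1 + m * kN1 + n1; // note the "+ 1" pad
    };
    for(int n0 = 0; n0 < kN0; ++n0)
        std::printf("slice n0=%d starts at element %d\n", n0, offset(n0, 0, 0));
    return 0;
}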
template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t MIterPerWarp = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return 2;
}
else
{
return 1;
}
}();
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Use the Bwd WarpGemm to ensure that Fwd's random values are consistent
// with Bwd's, except for headdim256.
constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>)
{
if constexpr(IsWG32)
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
}
else
{
if constexpr(IsWG32)
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
}
}();
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
randval_block_inner_part_dstr_encoding);
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
typename WG::CWarpDstrEncoding{});
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm, template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType, typename RandValOutputDataType,
typename PComputeWindow, typename PComputeWindow,
typename RandValDramWindow> typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute, PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const RandValDramWindow& randval_dram_window) const
{ {
...@@ -305,30 +520,177 @@ struct BlockDropout ...@@ -305,30 +520,177 @@ struct BlockDropout
constexpr index_t kMPerStep = MWarp * WG::kM; constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN; constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute // register distribute
auto randval = auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>()); make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval.kThreadElementSpaceSize == 16); static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); auto randval_lds_read_window =
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { make_tile_window(randval_lds_window.get_bottom_tensor_view(),
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { randval_lds_window.get_window_lengths(),
int block_row_start = (start_m0_idx / WG::kM) + i_m0; randval_lds_window.get_window_origin(),
int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); MakeRandValLdsShuffleTileDistribution<BlockGemm>());
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start); uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number // generate random number
uint8_t random_uint8_t[16]; uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol)); ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t[i_random_idx++]; p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
constexpr auto p_idx0 = ? p_compute[p_idx] * rp_undrop
tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{}; : PComputeDataType(0);
});
});
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
}
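The keep/drop rule above is standard inverted dropout: an element is kept when its random byte is <= p_undrop_in_uint8_t and rescaled by rp_undrop = 1 / p_undrop, so the expected value is unchanged. A standalone numeric sketch (how p_undrop_in_uint8_t is quantized from the keep probability is an assumption here):

#include <cstdint>
#include <cstdio>

int main()
{
    const float p_undrop = 0.8f; // keep probability
    const uint8_t p_undrop_u8 = static_cast<uint8_t>(p_undrop * 255.f);
    const float rp_undrop = 1.f / p_undrop;

    const uint8_t randval[4] = {10, 200, 150, 255};
    float p[4] = {1.f, 1.f, 1.f, 1.f};
    for(int i = 0; i < 4; ++i)
        p[i] = randval[i] <= p_undrop_u8 ? p[i] * rp_undrop : 0.f;
    for(int i = 0; i < 4; ++i)
        std::printf("p[%d] = %f\n", i, p[i]); // kept entries become 1.25, dropped 0
    return 0;
}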
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval = make_static_distributed_tensor<uint8_t>(
MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start, block_col_start;
if constexpr(IsWG32)
{
block_row_start = (start_m0_idx / WG::kM) + i_m0;
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
}
else
{
block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
            // (buffer hoisted above the branches so the pointer does not
            // dangle once each if-constexpr scope closes)
            uint8_t random_uint8_t[16];
            uint8_t* random_uint8_t_ = random_uint8_t;
            if constexpr(MBwdWG16SingleIterCheck)
            {
                // m0t0 ~m0t15/m0t32~m0t47: 0
                // m0t16~m0t31/m0t48~m0t63: 1
                // m1t0 ~m1t15/m1t32~m1t47: 2
                // m1t16~m1t31/m1t48~m1t63: 3
                const index_t start_idx =
                    ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
                ph.get_random_4x8(
                    random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
            }
            else if constexpr(MBwdWG16MultiIterCheck)
            {
                // t0 ~t15/t32~t47: 0
                // t16~t31/t48~t63: 1
                const index_t start_idx = (get_lane_id() >> 4) & 1;
                ph.get_random_8x8(
                    random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
            }
            else
            {
                ph.get_random_16x8(random_uint8_t,
                                   reinterpret_cast<unsigned long long&>(rowcol));
            }
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{}; constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
...@@ -337,19 +699,19 @@ struct BlockDropout ...@@ -337,19 +699,19 @@ struct BlockDropout
}); });
}); });
// save to Global // save to Global
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
const auto randval_store = cast_tile<RandValOutputDataType>(randval); const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store); store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0}); move_tile_window(randval_dram_window, {kMPerStep, 0});
} }
}); });
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
} }
}); });
if(is_store_randval) if constexpr(IsStoreRandval)
{ {
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
} }
...@@ -358,7 +720,6 @@ struct BlockDropout ...@@ -358,7 +720,6 @@ struct BlockDropout
ck_tile::philox ph; ck_tile::philox ph;
const float rp_undrop; const float rp_undrop;
const uint8_t p_undrop_in_uint8_t; const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
}; };
} // namespace ck_tile } // namespace ck_tile
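For reference, the WG16 start_idx selection in the Bwd Run() matches the lane/row comment tables above: start_idx = ((lane >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1) picks which 32-bit word of the shared Philox block each lane consumes. A host-side sketch of the mapping:

#include <cstdio>

int main()
{
    for(int start_m0_idx = 0; start_m0_idx < 32; start_m0_idx += 16)
        for(int lane = 0; lane < 64; lane += 16)
            std::printf("m0=%2d lane %2d -> word %d\n",
                        start_m0_idx, lane,
                        ((lane >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1));
    return 0;
}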