Unverified Commit f84e2020 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer Committed by GitHub
Browse files

Merge branch 'develop' into lwpck-1815

parents 408534d4 25935b57
......@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}
inline bool is_gfx101_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
......
......@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem.Next();
// clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -189,15 +185,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
arg_);
}
else
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
......@@ -214,8 +207,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......@@ -224,7 +215,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
......@@ -239,13 +229,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -255,8 +243,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -266,8 +254,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
......@@ -326,8 +313,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
......@@ -354,7 +340,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
......@@ -472,8 +457,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
......@@ -496,7 +479,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
......@@ -524,13 +506,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -539,8 +519,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -548,7 +528,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
......@@ -579,10 +558,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......@@ -591,7 +567,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
......@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......
......@@ -86,7 +86,6 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -101,10 +100,8 @@ __global__ void
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
......@@ -200,7 +197,6 @@ __global__ void
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -321,8 +317,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
ELayout>;
ADataType,
EDataType>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......@@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = 1;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -19,7 +19,7 @@ namespace device {
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......
......@@ -45,7 +45,7 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
OutElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......
......@@ -3,7 +3,6 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
......@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
UnaryOp2 unary_op2_{};
};
using ScaleScalePass = UnaryCombinedOp<Scale, Scale, PassThrough>;
using ScaleScaleRelu = UnaryCombinedOp<Scale, Scale, Relu>;
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
......@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
struct Problem
......@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
......@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value))
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
{
......
......@@ -36,10 +36,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
......@@ -56,7 +55,7 @@ __global__ void
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
......@@ -69,10 +68,9 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -93,7 +91,7 @@ __global__ void
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}
template <typename ALayout,
......@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}();
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
......@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// not pad M or N
return c_grid_desc_mraw_nraw;
}
#endif
}
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
......@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
{
......@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
{
......
......@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset);
}
template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
T* addr)
{
static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, half_t>::value)
{
vector_type<half_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
tmp.template AsType<bhalf2_t>()[i]);
});
}
#endif
}
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
......@@ -907,6 +935,16 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
if constexpr(is_same<T, bhalf_t>::value)
{
if(dst_thread_element_valid)
{
amd_global_atomic_add_impl<scalar_t, vector_size>(
src_thread_data, p_dst_wave + dst_thread_element_offset);
}
}
else
{
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
......@@ -919,6 +957,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
}
// buffer_atomic_max requires:
......
......@@ -358,13 +358,15 @@ struct DynamicBuffer
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
......
......@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
......@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
......@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
......@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
......@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
......@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
......@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
return xor_t<LowLengths>{low_lengths};
}
template <typename LowLength, typename OffsetLength>
......
......@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// u32
// using uint32_t = ...
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
......
......@@ -746,7 +746,8 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
return make_tuple(
make_static_tile_distribution(
tile_distribution_encoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
// change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
......
......@@ -53,6 +53,39 @@ class philox
out_tmp[3] = tmp_ph.w;
}
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
uint32x4_t tmp;
tmp[0] = tmp_ph.x;
tmp[1] = tmp_ph.y;
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
}
private:
struct ull2
{
......
......@@ -8,21 +8,16 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment