Commit c29dc4c5 authored by ltqin's avatar ltqin
Browse files

Merge branch 'develop' into conv_splitk_f32

parents 134af43b fd3d907a
...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr, StaticBufferOfVectorTypeV2<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, 16>, vector_type<FloatAcc, xdlops_gemm.GetRegSizePerXdlops()>,
MRepeat * NRepeat, MRepeat * NRepeat,
true> true>
c_thread_buf_; c_thread_buf_;
...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
...@@ -246,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -246,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
} }
private: private:
// A[K, M] // A[K0, M0, M1, M2, K1]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
// B[K, N] // B[K0, N0, N1, N2, K1]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
// C[M, N]
static constexpr auto c_thread_desc_ = static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
......
...@@ -14,6 +14,7 @@ namespace ck { ...@@ -14,6 +14,7 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize, template <index_t BlockSize,
typename SrcElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths, typename BlockSliceLengths,
typename ThreadSliceLengths, typename ThreadSliceLengths,
...@@ -39,12 +40,17 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -39,12 +40,17 @@ struct BlockwiseTensorSliceTransfer_v4
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc, __device__ constexpr BlockwiseTensorSliceTransfer_v4(
const Index& src_block_slice_origin, const SrcDesc& src_desc,
const DstDesc& dst_desc, const Index& src_block_slice_origin,
const Index& dst_block_slice_origin) const DstDesc& dst_desc,
: threadwise_transfer_( const Index& dst_block_slice_origin,
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>()) const SrcElementwiseOperation& src_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
src_element_op)
{ {
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() && static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
...@@ -147,6 +153,7 @@ struct BlockwiseTensorSliceTransfer_v4 ...@@ -147,6 +153,7 @@ struct BlockwiseTensorSliceTransfer_v4
using ThreadwiseTransfer = using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths, ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
SrcElementwiseOperation,
DstInMemOp, DstInMemOp,
SrcData, SrcData,
DstData, DstData,
......
...@@ -19,6 +19,9 @@ template <typename GridwiseGemm, ...@@ -19,6 +19,9 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
...@@ -32,6 +35,9 @@ __global__ void ...@@ -32,6 +35,9 @@ __global__ void
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -46,6 +52,9 @@ __global__ void ...@@ -46,6 +52,9 @@ __global__ void
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map); block_2_ctile_map);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...@@ -55,6 +64,9 @@ template <typename GridwiseGemm, ...@@ -55,6 +64,9 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2, typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap> typename Block2CTileMap>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -66,6 +78,9 @@ __global__ void ...@@ -66,6 +78,9 @@ __global__ void
const void CONSTANT* p_a_grid_desc_k0_m_k1, const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1, const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map) const void CONSTANT* p_block_2_ctile_map)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -80,6 +95,12 @@ __global__ void ...@@ -80,6 +95,12 @@ __global__ void
cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2)); cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>( const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map)); cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
...@@ -90,6 +111,9 @@ __global__ void ...@@ -90,6 +111,9 @@ __global__ void
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map); block_2_ctile_map);
} }
#endif #endif
...@@ -102,6 +126,9 @@ template <index_t BlockSize, ...@@ -102,6 +126,9 @@ template <index_t BlockSize,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K0PerBlock,
...@@ -353,6 +380,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -353,6 +380,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...@@ -411,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -411,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_K0_M_K1,
...@@ -432,11 +463,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -432,11 +463,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
true>(a_grid_desc_k0_m_k1, true>(a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_k0_m_k1, a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0),
a_element_op);
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_K0_N_K1,
...@@ -458,7 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -458,7 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
true>(b_grid_desc_k0_n_k1, true>(b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_k0_n_k1, b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0)); make_multi_index(0, 0, 0),
b_element_op);
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
...@@ -518,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -518,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// main body // main body
index_t k0_block_data_begin = 0; index_t k0_block_data_begin = 0;
c_thread_buf.Clear();
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
do do
...@@ -557,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -557,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
...@@ -569,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -569,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
make_naive_tensor_descriptor_packed(make_tuple(
Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
...@@ -610,6 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -610,6 +645,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatC, FloatC,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
CElementwiseOperation,
Sequence<M0, N0, I1, I1, M2, I1, M4, I1>, Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
...@@ -617,7 +653,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -617,7 +653,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{ true>{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(m_thread_data_on_grid_idx[I0], make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0], n_thread_data_on_grid_idx[I0],
...@@ -626,7 +661,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -626,7 +661,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3], m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4], m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2])}; n_thread_data_on_grid_idx[I2]),
c_element_op};
c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
......
...@@ -50,6 +50,7 @@ template <typename SrcData, ...@@ -50,6 +50,7 @@ template <typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
typename DstDesc, typename DstDesc,
typename DstElementwiseOperation,
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
index_t DstVectorDim, index_t DstVectorDim,
...@@ -68,9 +69,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -68,9 +69,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
const Index& dst_slice_origin_idx) const DstDesc& dst_desc,
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) const Index& dst_slice_origin_idx,
const DstElementwiseOperation& dst_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
dst_element_op_{dst_element_op}
{ {
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
...@@ -195,8 +199,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -195,8 +199,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) = dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_buf[Number<src_offset>{}]); type_convert<DstData>(dst_element_op_(src_buf[Number<src_offset>{}]));
}); });
const bool is_dst_valid = const bool is_dst_valid =
...@@ -373,6 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -373,6 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private: private:
DstCoord dst_coord_; DstCoord dst_coord_;
const DstElementwiseOperation dst_element_op_;
}; // namespace ck }; // namespace ck
// Assume: // Assume:
......
...@@ -46,6 +46,7 @@ struct lambda_scalar_per_access_for_src_and_dst ...@@ -46,6 +46,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 3. src_slice_origin and dst_slice_origin are not known at compile-time, // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer // 4. Use thread buffer
template <typename SliceLengths, template <typename SliceLengths,
typename SrcElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData, typename SrcData,
typename DstData, typename DstData,
...@@ -76,12 +77,15 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -76,12 +77,15 @@ struct ThreadwiseTensorSliceTransfer_v3r2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
const Index& src_slice_origin, const SrcDesc& src_desc,
const DstDesc& dst_desc, const Index& src_slice_origin,
const Index& dst_slice_origin) const DstDesc& dst_desc,
const Index& dst_slice_origin,
const SrcElementwiseOperation& src_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op)
{ {
} }
...@@ -191,12 +195,22 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -191,12 +195,22 @@ struct ThreadwiseTensorSliceTransfer_v3r2
const bool is_src_valid = const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type; using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf to src_thread_scratch_ // copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
src_vector_container.template AsType<SrcData>()(i) =
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_.template SetAsType<src_vector_t>( src_thread_scratch_.template SetAsType<src_vector_t>(
src_data_idx_seq, src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -796,6 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -796,6 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_; SrcCoord src_coord_;
DstCoord dst_coord_; DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
}; };
} // namespace ck } // namespace ck
......
...@@ -507,7 +507,7 @@ struct MfmaSelector ...@@ -507,7 +507,7 @@ struct MfmaSelector
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ static constexpr void mfma_check() __host__ __device__ constexpr MfmaSelector()
{ {
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk, selected_mfma.num_regs_per_blk,
...@@ -533,8 +533,6 @@ struct MfmaSelector ...@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!"); "is_k_reduction wrong!");
} }
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast() static constexpr bool IsABroadcast()
{ {
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
...@@ -547,7 +545,7 @@ struct MfmaSelector ...@@ -547,7 +545,7 @@ struct MfmaSelector
selected_mfma.k_per_blk; selected_mfma.k_per_blk;
} }
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
}; };
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
...@@ -621,6 +619,8 @@ struct XdlopsGemm ...@@ -621,6 +619,8 @@ struct XdlopsGemm
return MPerXdlops * NPerXdlops / mfma_instr.wave_size; return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
} }
__device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
...@@ -708,7 +708,7 @@ struct XdlopsGemm ...@@ -708,7 +708,7 @@ struct XdlopsGemm
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetKPerThread(); static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
......
...@@ -141,6 +141,11 @@ ...@@ -141,6 +141,11 @@
#define CK_USE_SPLITK_XDLOPS 1 #define CK_USE_SPLITK_XDLOPS 1
#endif #endif
// workaround for register spill due to compiler issue, when casting type between fp32 and fp16
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
#endif
namespace ck { namespace ck {
enum InMemoryDataOperationEnum_t enum InMemoryDataOperationEnum_t
......
...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
static constexpr index_t vector_size = GetVectorSize(); static constexpr index_t vector_size = GetVectorSize();
__host__ __device__ static constexpr index_t GetNumVectors() { return N; }
__host__ __device__ static constexpr index_t GetNumElements()
{
return GetVectorSize() * GetNumVectors();
}
VecBaseType invalid_element_value_ = VecBaseType{0}; VecBaseType invalid_element_value_ = VecBaseType{0};
T invalid_vec_value_ = T{0}; T invalid_vec_value_ = T{0};
...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N> ...@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
return GetElement(i, true); return GetElement(i, true);
} }
__host__ __device__ void Clear()
{
static_for<0, GetNumElements(), 1>{}(
[&](auto i) { GetElement(i, true) = invalid_element_value_; });
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp" #include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK; ...@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple< using device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk = std::tuple<
// clang-format off // clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##############| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> DeviceConvFwdXdl< 2, F16, F16, F16, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_conv_fwd_instance<2, F16, F16, F16, NHWC, KYXC, NHWK>( void add_device_conv_fwd_instance<2, F16, F16, F16, NHWC, KYXC, NHWK>(
std::vector<DeviceConvFwdPtr>& device_conv_instances) std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& device_conv_instances)
{ {
using DeviceConvs = device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk; using DeviceConvs = device_conv_fwd_xdl_instances_f16_f16_f16_nhwc_kyxc_nhwk;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "device_conv_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "device_conv_instance.hpp" #include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK; ...@@ -18,32 +19,34 @@ using NHWK = ck::tensor_layout::convolution::NHWK;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple< using device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk = std::tuple<
// clang-format off // clang-format off
//##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##############| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##############| Spatial| Type| Type| Type| Type| Layout| Layout| Layout| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##############| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##############| | | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##############| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> DeviceConvFwdXdl< 2, F32, F32, F32, F32, NHWC, KYXC, NHWK, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_conv_fwd_instance<2, F32, F32, F32, NHWC, KYXC, NHWK>( void add_device_conv_fwd_instance<2, F32, F32, F32, NHWC, KYXC, NHWK>(
std::vector<DeviceConvFwdPtr>& device_conv_instances) std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& device_conv_instances)
{ {
using DeviceConvs = device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk; using DeviceConvs = device_conv_fwd_xdl_instances_f32_f32_f32_nhwc_kyxc_nhwk;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple< using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_gemm_instance<F16, F16, F16, Col, Row, Row>( void add_device_gemm_instance<F16, F16, F16, Col, Row, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple< using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_gemm_instance<F16, F16, F16, Col, Col, Row>( void add_device_gemm_instance<F16, F16, F16, Col, Col, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -17,27 +18,29 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple< using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true> DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_gemm_instance<F16, F16, F16, Row, Row, Row>( void add_device_gemm_instance<F16, F16, F16, Row, Row, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -17,32 +18,34 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -17,32 +18,34 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple< using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true> DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
template <> template <>
void add_device_gemm_instance<F16, F16, F16, Row, Col, Row>( void add_device_gemm_instance<F16, F16, F16, Row, Col, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
#if CK_USE_SPLITK_XDLOPS #if CK_USE_SPLITK_XDLOPS
using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< ...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
#else #else
using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
#endif #endif
template <> template <>
void add_device_gemm_instance<F32, F32, F32, Col, Row, Row>( void add_device_gemm_instance<F32, F32, F32, Col, Row, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
#if CK_USE_SPLITK_XDLOPS #if CK_USE_SPLITK_XDLOPS
using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< ...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
#else #else
using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
#endif #endif
template <> template <>
void add_device_gemm_instance<F32, F32, F32, Col, Col, Row>( void add_device_gemm_instance<F32, F32, F32, Col, Col, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
#if CK_USE_SPLITK_XDLOPS #if CK_USE_SPLITK_XDLOPS
using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< ...@@ -39,25 +42,25 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
#else #else
using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true> DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
#endif #endif
template <> template <>
void add_device_gemm_instance<F32, F32, F32, Row, Row, Row>( void add_device_gemm_instance<F32, F32, F32, Row, Row, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -18,6 +19,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
#if CK_USE_SPLITK_XDLOPS #if CK_USE_SPLITK_XDLOPS
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
...@@ -44,30 +47,30 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< ...@@ -44,30 +47,30 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
#else #else
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
#endif #endif
template <> template <>
void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>( void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>(
std::vector<DeviceGemmPtr>& device_op_instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{ {
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment