Commit 2066a3d4 authored by Chao Liu

add pointwise operation to A/B matrix

parent 496e2ec6
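This change threads per-element "pointwise" operations for the A and B operands through the whole stack: the DeviceGemm interface, the gridwise GEMM kernel, the blockwise copy, and down to the threadwise transfers that touch each element, alongside the pre-existing C operation. Any functor with a device-callable call operator can serve as an element-wise operation; below is a minimal sketch of the expected shape (the example file in this commit defines "Equal" and "Relu" in exactly this style):

// Sketch only: the A/B/C element-wise template parameters added in this
// commit expect a functor like this, callable on both host and device.
struct PassThroughSketch // hypothetical name; the commit's example uses "Equal"
{
    template <typename T>
    __host__ __device__ constexpr T operator()(T v) const
    {
        return v; // identity: leaves the loaded element unchanged
    }
};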
@@ -14,6 +14,7 @@ namespace ck {
 // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
 // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
 template <index_t BlockSize,
+          typename SrcElementwiseOperation,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename BlockSliceLengths,
           typename ThreadSliceLengths,
@@ -39,12 +40,17 @@ struct BlockwiseTensorSliceTransfer_v4
     using Index = MultiIndex<nDim>;
 
-    __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
-                                                         const Index& src_block_slice_origin,
-                                                         const DstDesc& dst_desc,
-                                                         const Index& dst_block_slice_origin)
-        : threadwise_transfer_(
-              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
+    __device__ constexpr BlockwiseTensorSliceTransfer_v4(
+        const SrcDesc& src_desc,
+        const Index& src_block_slice_origin,
+        const DstDesc& dst_desc,
+        const Index& dst_block_slice_origin,
+        const SrcElementwiseOperation& src_element_op)
+        : threadwise_transfer_(src_desc,
+                               make_zero_multi_index<nDim>(),
+                               dst_desc,
+                               make_zero_multi_index<nDim>(),
+                               src_element_op)
     {
         static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
@@ -147,6 +153,7 @@ struct BlockwiseTensorSliceTransfer_v4
     using ThreadwiseTransfer =
         ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
+                                           SrcElementwiseOperation,
                                            DstInMemOp,
                                            SrcData,
                                            DstData,
...
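The blockwise transfer never applies the operation itself; it only forwards the functor into its embedded threadwise transfer, which applies it per element. A host-only analogue of that forwarding pattern (a sketch with invented names, not CK code):

#include <cstddef>
#include <vector>

template <typename Op>
struct ThreadwiseCopySketch
{
    explicit ThreadwiseCopySketch(Op op) : op_(op) {}

    void Run(const std::vector<float>& src, std::vector<float>& dst) const
    {
        for(std::size_t i = 0; i < src.size(); ++i)
            dst[i] = op_(src[i]); // op applied in the innermost copy loop
    }

    Op op_;
};

template <typename Op>
struct BlockwiseCopySketch
{
    // the blockwise layer just forwards the functor, mirroring the
    // constructor change above
    explicit BlockwiseCopySketch(Op op) : threadwise_(op) {}

    ThreadwiseCopySketch<Op> threadwise_;
};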
@@ -19,9 +19,11 @@ template <typename GridwiseGemm,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainKBlockLoop,
-          typename CElementwiseOperation>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -33,8 +35,10 @@ __global__ void
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        const Block2CTileMap block_2_ctile_map,
-        const CElementwiseOperation c_element_op)
+        const AElementwiseOperation a_element_op,
+        const BElementwiseOperation b_element_op,
+        const CElementwiseOperation c_element_op,
+        const Block2CTileMap block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -48,8 +52,10 @@ __global__ void
                       a_grid_desc_k0_m_k1,
                       b_grid_desc_k0_n_k1,
                       c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                      block_2_ctile_map,
-                      c_element_op);
+                      a_element_op,
+                      b_element_op,
+                      c_element_op,
+                      block_2_ctile_map);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,
@@ -58,8 +64,10 @@ template <typename GridwiseGemm,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
-          typename Block2CTileMap,
-          typename CElementwiseOperation>
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -70,8 +78,10 @@ __global__ void
         const void CONSTANT* p_a_grid_desc_k0_m_k1,
         const void CONSTANT* p_b_grid_desc_k0_n_k1,
         const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        const void CONSTANT* p_block_2_ctile_map,
-        const void CONSTANT* p_c_element_op)
+        const void CONSTANT* p_a_element_op,
+        const void CONSTANT* p_b_element_op,
+        const void CONSTANT* p_c_element_op,
+        const void CONSTANT* p_block_2_ctile_map)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -85,6 +95,10 @@ __global__ void
         cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
     const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
         cast_pointer_to_generic_address_space(p_block_2_ctile_map));
+    const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
+        cast_pointer_to_generic_address_space(p_a_element_op));
+    const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
+        cast_pointer_to_generic_address_space(p_b_element_op));
     const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
         cast_pointer_to_generic_address_space(p_c_element_op));
@@ -97,8 +111,10 @@ __global__ void
                      a_grid_desc_k0_m_k1,
                      b_grid_desc_k0_n_k1,
                      c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                     block_2_ctile_map,
-                     c_element_op);
+                     a_element_op,
+                     b_element_op,
+                     c_element_op,
+                     block_2_ctile_map);
 }
 #endif
@@ -110,6 +126,8 @@ template <index_t BlockSize,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M_N,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation,
           index_t MPerBlock,
           index_t NPerBlock,
@@ -362,8 +380,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        const Block2CTileMap& block_2_ctile_map,
-        const CElementwiseOperation& c_element_op)
+        const AElementwiseOperation& a_element_op,
+        const BElementwiseOperation& b_element_op,
+        const CElementwiseOperation& c_element_op,
+        const Block2CTileMap& block_2_ctile_map)
 {
     const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
         p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -421,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // A matrix blockwise copy
     auto a_blockwise_copy =
         BlockwiseTensorSliceTransfer_v4<BlockSize,
+                                        AElementwiseOperation,
                                         InMemoryDataOperationEnum_t::Set,
                                         Sequence<K0PerBlock, MPerBlock, K1>,
                                         ABlockTransferThreadSliceLengths_K0_M_K1,
@@ -442,11 +463,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             true>(a_grid_desc_k0_m_k1,
                   make_multi_index(0, m_block_data_idx_on_grid, 0),
                   a_block_desc_k0_m_k1,
-                  make_multi_index(0, 0, 0));
+                  make_multi_index(0, 0, 0),
+                  a_element_op);
     // B matrix blockwise copy
     auto b_blockwise_copy =
         BlockwiseTensorSliceTransfer_v4<BlockSize,
+                                        BElementwiseOperation,
                                         InMemoryDataOperationEnum_t::Set,
                                         Sequence<K0PerBlock, NPerBlock, K1>,
                                         BBlockTransferThreadSliceLengths_K0_N_K1,
@@ -468,7 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             true>(b_grid_desc_k0_n_k1,
                   make_multi_index(0, n_block_data_idx_on_grid, 0),
                   b_block_desc_k0_n_k1,
-                  make_multi_index(0, 0, 0));
+                  make_multi_index(0, 0, 0),
+                  b_element_op);
     // GEMM definition
     // c_mtx += transpose(a_mtx) * b_mtx
...
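Note the argument reordering: the element-wise operations are now passed by value between the C descriptor and the tile map, in both kernel signatures and GridwiseGemm::Run. Semantically, the fused kernel now computes C(m, n) = c_op( sum_k a_op(A(m, k)) * b_op(B(k, n)) ). A tiny host model of that contract (a sketch; the real kernel tiles over K0/M/N/K1 and uses XDLOPS):

template <typename AOp, typename BOp, typename COp>
void gemm_model(const float* a, const float* b, float* c,
                int M, int N, int K, AOp a_op, BOp b_op, COp c_op)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            double acc = 0.0; // accumulate in double, like the host reference

            for(int k = 0; k < K; ++k)
                acc += static_cast<double>(a_op(a[m * K + k])) *
                       static_cast<double>(b_op(b[k * N + n]));

            c[m * N + n] = c_op(static_cast<float>(acc));
        }
}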
@@ -50,7 +50,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename ElementwiseOp,
+          typename SrcElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
@@ -69,11 +69,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
-                                                            const Index& dst_slice_origin_idx,
-                                                            const ElementwiseOp element_op)
-        : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          element_op_{element_op}
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
+        const DstDesc& dst_desc,
+        const Index& dst_slice_origin_idx,
+        const SrcElementwiseOperation src_element_op)
+        : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
+          src_element_op_{src_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
@@ -200,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
             // apply element-wise operation and type convert
             dst_vector.template AsType<DstData>()(i) =
-                type_convert<DstData>(element_op_(src_buf[Number<src_offset>{}]));
+                type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
         });
 
         const bool is_dst_valid =
@@ -377,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;
-    ElementwiseOp element_op_;
+    SrcElementwiseOperation src_element_op_;
 }; // namespace ck
 
 // Assume:
...
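In v1r3 this is purely a rename for consistency (ElementwiseOp becomes SrcElementwiseOperation); the operation is still applied scalar-by-scalar, immediately before the type conversion into the destination vector. The pattern in isolation (a sketch; CK uses type_convert<DstData> rather than static_cast):

template <typename DstData, typename SrcData, typename Op>
__host__ __device__ DstData apply_then_convert(SrcData v, Op op)
{
    // element-wise operation first, then the narrowing/widening conversion
    return static_cast<DstData>(op(v));
}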
@@ -46,6 +46,7 @@ struct lambda_scalar_per_access_for_src_and_dst
 // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
 // 4. Use thread buffer
 template <typename SliceLengths,
+          typename SrcElementwiseOperation,
           InMemoryDataOperationEnum_t DstInMemOp,
           typename SrcData,
           typename DstData,
@@ -76,12 +77,15 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
-    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc,
-                                                            const Index& src_slice_origin,
-                                                            const DstDesc& dst_desc,
-                                                            const Index& dst_slice_origin)
-        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
-          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
+        const SrcDesc& src_desc,
+        const Index& src_slice_origin,
+        const DstDesc& dst_desc,
+        const Index& dst_slice_origin,
+        const SrcElementwiseOperation& src_element_op)
+        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
+          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
+          src_element_op_(src_element_op)
     {
     }
@@ -191,12 +195,22 @@ struct ThreadwiseTensorSliceTransfer_v3r2
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
 
-            using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type;
+            using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
+            using src_vector_t    = typename src_vector_type::type;
 
-            // copy data from src_buf to src_thread_scratch_
-            src_thread_scratch_.template SetAsType<src_vector_t>(
-                src_data_idx_seq,
-                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));
+            // copy data from src_buf into src_vector_container
+            auto src_vector_container = src_vector_type{
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
+
+            // apply SrcElementwiseOperation on src_vector_container
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                src_vector_container.template AsType<SrcData>()(i) =
+                    src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
+            });
+
+            // copy data from src_vector_container into src_thread_scratch_
+            src_thread_scratch_.template SetAsType<src_vector_t>(
+                src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
 
             constexpr auto move_on_dim = [&]() constexpr
             {
@@ -796,6 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     SrcCoord src_coord_;
     DstCoord dst_coord_;
+    SrcElementwiseOperation src_element_op_;
 };
 
 } // namespace ck
...
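In v3r2 the operation is applied between the vectorized global load and the write into the thread scratch: the whole vector is fetched once, then each lane is rewritten in place. A host analogue of that load-apply-store path (a sketch):

#include <array>
#include <cstddef>

template <typename T, std::size_t N, typename Op>
std::array<T, N> load_apply_sketch(const T* src, Op op)
{
    std::array<T, N> v{};

    for(std::size_t i = 0; i < N; ++i)
        v[i] = src[i]; // one vectorized load in the real code

    for(std::size_t i = 0; i < N; ++i)
        v[i] = op(v[i]); // per-lane apply, like the static_for above

    return v; // stored as a whole vector into the thread scratch
}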
@@ -8,7 +8,9 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemm : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
@@ -21,13 +23,18 @@ struct DeviceGemm : public BaseOperator
                         ck::index_t StrideA,
                         ck::index_t StrideB,
                         ck::index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <typename CElementwiseOperation>
-using DeviceGemmPtr = std::unique_ptr<DeviceGemm<CElementwiseOperation>>;
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGemmPtr = std::unique_ptr<
+    DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
 
 } // namespace device
 } // namespace tensor_operation
...
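Callers now fix all three functor types when they pick a DeviceGemm, and pass live instances when building the argument. A stripped-down host analogue of the widened interface (a sketch, not the CK header):

#include <memory>

template <typename AOp, typename BOp, typename COp>
struct DeviceGemmSketch
{
    virtual ~DeviceGemmSketch() = default;

    // concrete operators receive one instance of each functor per argument
    virtual void MakeArgument(AOp a_op, BOp b_op, COp c_op) = 0;
};

template <typename AOp, typename BOp, typename COp>
using DeviceGemmSketchPtr = std::unique_ptr<DeviceGemmSketch<AOp, BOp, COp>>;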
@@ -22,6 +22,8 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
@@ -50,7 +52,8 @@ template <typename ADataType,
           ck::index_t CThreadTransferDstScalarPerVector,
           bool ABlockLdsAddExtraM,
           bool BBlockLdsAddExtraN>
-struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
+struct DeviceGemmXdl
+    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -177,6 +180,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
+        AElementwiseOperation,
+        BElementwiseOperation,
         CElementwiseOperation,
         MPerBlock,
         NPerBlock,
@@ -233,6 +238,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                  index_t StrideC,
                  index_t M01,
                  index_t N01,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
@@ -244,6 +251,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
+             a_element_op_{a_element_op},
+             b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
@@ -271,6 +280,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
        Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
+       AElementwiseOperation a_element_op_;
+       BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };
@@ -321,9 +332,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                    remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    true,
-                    CElementwiseOperation>;
+                    true>;
 
                ave_time = launch_and_time_kernel(kernel,
                                                  nrepeat,
@@ -336,8 +349,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                                  arg.a_grid_desc_k0_m_k1_,
                                                  arg.b_grid_desc_k0_n_k1_,
                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_,
-                                                  arg.c_element_op_);
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
            }
            else
            {
@@ -348,9 +363,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                    remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
                    remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    false,
-                    CElementwiseOperation>;
+                    false>;
 
                ave_time = launch_and_time_kernel(kernel,
                                                  nrepeat,
@@ -363,8 +380,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                                  arg.a_grid_desc_k0_m_k1_,
                                                  arg.b_grid_desc_k0_n_k1_,
                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_,
-                                                  arg.c_element_op_);
+                                                  arg.a_element_op_,
+                                                  arg.b_element_op_,
+                                                  arg.c_element_op_,
+                                                  arg.block_2_ctile_map_);
            }
 
            return ave_time;
@@ -407,9 +426,24 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                            index_t StrideA,
                            index_t StrideB,
                            index_t StrideC,
+                            AElementwiseOperation a_element_op,
+                            BElementwiseOperation b_element_op,
                            CElementwiseOperation c_element_op)
    {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1, c_element_op};
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op};
    }
 
    static auto MakeInvoker() { return Invoker{}; }
@@ -424,6 +458,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                            index_t StrideA,
                            index_t StrideB,
                            index_t StrideC,
+                            AElementwiseOperation a_element_op,
+                            BElementwiseOperation b_element_op,
                            CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -437,6 +473,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
                                          StrideC,
                                          1,
                                          1,
+                                          a_element_op,
+                                          b_element_op,
                                          c_element_op);
    }
...
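The pattern in DeviceGemmXdl is capture-once, replay-per-launch: the functors are copied into Argument and forwarded on every kernel launch in A, B, C order, ahead of the tile map. In miniature (a sketch with invented names):

template <typename AOp, typename BOp, typename COp>
struct ArgumentSketch
{
    AOp a_op_; // stored by value, as in DeviceGemmXdl::Argument
    BOp b_op_;
    COp c_op_;
};

template <typename AOp, typename BOp, typename COp, typename Launch>
void invoke_sketch(const ArgumentSketch<AOp, BOp, COp>& arg, Launch&& launch)
{
    // replayed at each launch, in the same order the kernel expects
    launch(arg.a_op_, arg.b_op_, arg.c_op_);
}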
@@ -14,13 +14,22 @@
 #include "device_base.hpp"
 #include "device_gemm_xdl.hpp"
 
-struct Activation
+struct Equal
+{
+    template <typename T>
+    __host__ __device__ constexpr T operator()(T v) const
+    {
+        return v;
+    }
+};
+
+struct Relu
 {
     float alpha = 0.1;
 
     // ReLU
     template <typename T>
-    __host__ __device__ T operator()(T v) const
+    __host__ __device__ constexpr T operator()(T v) const
     {
         T tmp = alpha * v;
         return tmp > 0 ? tmp : 0;
@@ -33,16 +42,22 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGemmInstance;
 
-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemmInstance<ck::half_t,
                           ck::half_t,
                           ck::half_t,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
                           ck::tensor_layout::gemm::RowMajor,
+                          AElementwiseOperation,
+                          BElementwiseOperation,
                           CElementwiseOperation>
 {
     using F16 = ck::half_t;
@@ -54,24 +69,32 @@ struct DeviceGemmInstance<ck::half_t,
     template <ck::index_t... Is>
     using S = ck::Sequence<Is...>;
 
+    using AOp = AElementwiseOperation;
+    using BOp = BElementwiseOperation;
+    using COp = CElementwiseOperation;
+
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
     // clang-format on
 };
 
-template <typename CElementwiseOperation>
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 struct DeviceGemmInstance<float,
                           float,
                           float,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
                           ck::tensor_layout::gemm::RowMajor,
+                          AElementwiseOperation,
+                          BElementwiseOperation,
                           CElementwiseOperation>
 {
     using F16 = ck::half_t;
@@ -83,14 +106,18 @@ struct DeviceGemmInstance<float,
     template <ck::index_t... Is>
     using S = ck::Sequence<Is...>;
 
+    using AOp = AElementwiseOperation;
+    using BOp = BElementwiseOperation;
+    using COp = CElementwiseOperation;
+
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
     // clang-format on
 };
@@ -177,9 +204,9 @@ int main(int argc, char* argv[])
                 ALayout,
                 BLayout,
                 CLayout,
-                Activation>::type{};
-
-    auto activation = Activation{};
+                Equal,
+                Equal,
+                Relu>::type{};
 
     auto invoker = gemm.MakeInvoker();
 
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
@@ -191,7 +218,9 @@ int main(int argc, char* argv[])
                                      StrideA,
                                      StrideB,
                                      StrideC,
-                                      activation);
+                                      Equal{},
+                                      Equal{},
+                                      Relu{});
 
    if(!gemm.IsSupportedArgument(argument))
    {
@@ -217,7 +246,7 @@ int main(int argc, char* argv[])
    if(do_verification)
    {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, activation);
+        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, Equal{}, Equal{}, Relu{});
 
        check_error(c_m_n_host_result, c_m_n_device_result);
    }
...
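Note that the example's "Relu" is a scaled variant: it returns alpha * v for positive inputs and 0 otherwise. A quick host-only check of the two functors as defined above (a sketch; the __host__ __device__ qualifiers are dropped so it compiles as plain C++):

#include <cstdio>

struct Equal
{
    template <typename T>
    constexpr T operator()(T v) const { return v; }
};

struct Relu
{
    float alpha = 0.1;

    template <typename T>
    constexpr T operator()(T v) const
    {
        T tmp = alpha * v;
        return tmp > 0 ? tmp : 0;
    }
};

int main()
{
    // Equal leaves values alone; Relu zeroes negatives and scales positives
    std::printf("%f %f\n", Equal{}(-2.0f), Relu{}(-2.0f)); // -2.000000 0.000000
    std::printf("%f %f\n", Equal{}(3.0f), Relu{}(3.0f));   // 3.000000 0.300000
    return 0;
}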
 #pragma once
 #include "host_tensor.hpp"
 
-template <typename AType, typename BType, typename CType, typename CElementwiseOperation>
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
 void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                         const Tensor<BType>& b_k_n,
                         Tensor<CType>& c_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
                         const CElementwiseOperation& c_element_op)
 {
     auto f_mk_kn_mn = [&](auto m, auto n) {
@@ -14,7 +21,8 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
         for(int k = 0; k < K; ++k)
         {
-            v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
+            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<const double>(b_element_op(b_k_n(k, n)));
         }
 
         c_m_n(m, n) = c_element_op(v);
...