Commit 496e2ec6 authored by Chao Liu

move C pointwise operation into threadwise copy

parent f0201ead
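In short: the C element-wise operation is no longer applied as a separate static_for pass over the per-thread accumulator buffer in GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; instead the functor is handed to ThreadwiseTensorSliceTransfer_v1r3, which applies it to each element right before the destination-type conversion and the store to global memory. A minimal standalone sketch of the idea (simplified, hypothetical names, not the actual CK types):

#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical element-wise functor, a stand-in for CElementwiseOperation.
struct Scale
{
    float alpha = 0.1f;
    float operator()(float x) const { return alpha * x; }
};

// After this commit the copy itself applies the functor right before converting
// to the destination type, instead of a separate pass over the source buffer.
template <typename SrcT, typename DstT, typename ElementwiseOp, std::size_t N>
void threadwise_copy(const std::array<SrcT, N>& src, std::array<DstT, N>& dst, ElementwiseOp op)
{
    for(std::size_t i = 0; i < N; ++i)
        dst[i] = static_cast<DstT>(op(src[i])); // element-wise op fused into the store
}

int main()
{
    std::array<float, 4> acc{1.f, 2.f, 3.f, 4.f};
    std::array<float, 4> out{};
    threadwise_copy(acc, out, Scale{});
    for(float v : out)
        std::printf("%f\n", v);
    return 0;
}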
@@ -34,7 +34,7 @@ __global__ void
        const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
        const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
        const Block2CTileMap block_2_ctile_map,
-       const CElementwiseOperation c_elementwise_op)
+       const CElementwiseOperation c_element_op)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -49,7 +49,7 @@ __global__ void
        b_grid_desc_k0_n_k1,
        c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
        block_2_ctile_map,
-       c_elementwise_op);
+       c_element_op);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
@@ -58,7 +58,8 @@ template <typename GridwiseGemm,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
-         typename Block2CTileMap>
+         typename Block2CTileMap,
+         typename CElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -69,7 +70,8 @@ __global__ void
        const void CONSTANT* p_a_grid_desc_k0_m_k1,
        const void CONSTANT* p_b_grid_desc_k0_n_k1,
        const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-       const void CONSTANT* p_block_2_ctile_map)
+       const void CONSTANT* p_block_2_ctile_map,
+       const void CONSTANT* p_c_element_op)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -83,6 +85,8 @@ __global__ void
        cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
    const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
        cast_pointer_to_generic_address_space(p_block_2_ctile_map));
+   const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
+       cast_pointer_to_generic_address_space(p_c_element_op));
    __shared__ FloatAB p_shared_block[shared_block_size];
@@ -93,7 +97,8 @@ __global__ void
        a_grid_desc_k0_m_k1,
        b_grid_desc_k0_n_k1,
        c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-       block_2_ctile_map);
+       block_2_ctile_map,
+       c_element_op);
}
#endif
@@ -105,7 +110,7 @@ template <index_t BlockSize,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
-         typename CElementwiseOp,
+         typename CElementwiseOperation,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
@@ -358,7 +363,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
        const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
        const Block2CTileMap& block_2_ctile_map,
-       const CElementwiseOp& c_elementwise_op)
+       const CElementwiseOperation& c_element_op)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -578,10 +583,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            make_naive_tensor_descriptor_packed(make_tuple(
                Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));
-       // elementwise Op to C
-       static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
-           [&](auto i) { c_thread_buf(i) = c_elementwise_op(c_thread_buf[i]); });
        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
@@ -619,6 +620,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            FloatC,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
+           CElementwiseOperation,
            Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
@@ -626,7 +628,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            CGlobalMemoryDataOperation,
            1,
            true>{
            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
            make_multi_index(m_thread_data_on_grid_idx[I0],
                             n_thread_data_on_grid_idx[I0],
@@ -635,7 +636,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
                             m_thread_data_on_grid_idx[I2],
                             m_thread_data_on_grid_idx[I3],
                             m_thread_data_on_grid_idx[I4],
-                            n_thread_data_on_grid_idx[I2])};
+                            n_thread_data_on_grid_idx[I2]),
+           c_element_op};
        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
......
@@ -50,6 +50,7 @@ template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
+         typename ElementwiseOp,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
@@ -69,8 +70,10 @@ struct ThreadwiseTensorSliceTransfer_v1r3
    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
-                                                            const Index& dst_slice_origin_idx)
-       : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
+                                                            const Index& dst_slice_origin_idx,
+                                                            const ElementwiseOp element_op)
+       : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
+         element_op_{element_op}
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
@@ -195,8 +198,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
+               // apply element-wise operation and type convert
                dst_vector.template AsType<DstData>()(i) =
-                   type_convert<DstData>(src_buf[Number<src_offset>{}]);
+                   type_convert<DstData>(element_op_(src_buf[Number<src_offset>{}]));
            });
            const bool is_dst_valid =
@@ -373,6 +377,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
    private:
    DstCoord dst_coord_;
+   ElementwiseOp element_op_;
}; // namespace ck
// Assume:
......
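Because the element-wise op is now a constructor argument of ThreadwiseTensorSliceTransfer_v1r3, every user of the transfer has to supply some functor; a transfer that should leave values untouched can pass a trivial pass-through op. An illustrative sketch (the functor name and the commented-out usage are assumptions, not necessarily what CK ships):

// Illustrative identity element-wise op; the name is hypothetical.
struct PassThrough
{
    template <typename T>
    __host__ __device__ constexpr T operator()(const T& x) const
    {
        return x;
    }
};

// Usage shape (template arguments elided): the op is the extra constructor argument,
// mirroring the gridwise GEMM change above.
// auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<...>{
//     dst_desc, dst_slice_origin_idx, PassThrough{}};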
@@ -36,15 +36,6 @@ struct BaseOperator
    virtual ~BaseOperator() {}
};
-struct BaseGpuOperator
-{
-    BaseGpuOperator() = default;
-    BaseGpuOperator(const BaseGpuOperator&) = default;
-    BaseGpuOperator& operator=(const BaseGpuOperator&) = default;
-    virtual ~BaseGpuOperator() {}
-};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......
@@ -8,6 +8,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
+template <typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
@@ -20,12 +21,13 @@ struct DeviceGemm : public BaseOperator
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
-                       std::unique_ptr<BaseGpuOperator> c_element_op_ptr) = 0;
+                       CElementwiseOperation c_element_op) = 0;
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
-using DeviceGemmPtr = std::unique_ptr<DeviceGemm>;
+template <typename CElementwiseOperation>
+using DeviceGemmPtr = std::unique_ptr<DeviceGemm<CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
@@ -50,7 +50,7 @@ template <typename ADataType,
          ck::index_t CThreadTransferDstScalarPerVector,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN>
-struct DeviceGemmXdl : public DeviceGemm
+struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -233,7 +233,7 @@ struct DeviceGemmXdl : public DeviceGemm
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
-                CElementwiseOperation c_elementwise_op)
+                CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
@@ -244,7 +244,7 @@ struct DeviceGemmXdl : public DeviceGemm
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
-             c_elementwise_op_{c_elementwise_op}
+             c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -271,7 +271,7 @@ struct DeviceGemmXdl : public DeviceGemm
        Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
-       CElementwiseOperation c_elementwise_op_;
+       CElementwiseOperation c_element_op_;
    };
    // Invoker
@@ -337,7 +337,7 @@ struct DeviceGemmXdl : public DeviceGemm
                    arg.b_grid_desc_k0_n_k1_,
                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                    arg.block_2_ctile_map_,
-                   arg.c_elementwise_op_);
+                   arg.c_element_op_);
            }
            else
            {
@@ -364,7 +364,7 @@ struct DeviceGemmXdl : public DeviceGemm
                    arg.b_grid_desc_k0_n_k1_,
                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                    arg.block_2_ctile_map_,
-                   arg.c_elementwise_op_);
+                   arg.c_element_op_);
            }
            return ave_time;
@@ -407,36 +407,24 @@ struct DeviceGemmXdl : public DeviceGemm
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
-                            std::unique_ptr<BaseGpuOperator> c_op_ptr)
+                            CElementwiseOperation c_element_op)
    {
-       return Argument{p_a,
-                       p_b,
-                       p_c,
-                       M,
-                       N,
-                       K,
-                       StrideA,
-                       StrideB,
-                       StrideC,
-                       1,
-                       1,
-                       *dynamic_cast<CElementwiseOperation*>(c_op_ptr.get())};
+       return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1, c_element_op};
    }
    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
-   std::unique_ptr<BaseArgument>
-   MakeArgumentPointer(const void* p_a,
-                       const void* p_b,
-                       void* p_c,
-                       index_t M,
-                       index_t N,
-                       index_t K,
-                       index_t StrideA,
-                       index_t StrideB,
-                       index_t StrideC,
-                       std::unique_ptr<BaseGpuOperator> c_op_ptr) override
+   std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                     const void* p_b,
+                                                     void* p_c,
+                                                     index_t M,
+                                                     index_t N,
+                                                     index_t K,
+                                                     index_t StrideA,
+                                                     index_t StrideB,
+                                                     index_t StrideC,
+                                                     CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -449,7 +437,7 @@ struct DeviceGemmXdl : public DeviceGemm
                                          StrideC,
                                          1,
                                          1,
-                                         *dynamic_cast<CElementwiseOperation*>(c_op_ptr.get()));
+                                         c_element_op);
    }
    // polymorphic
......
@@ -14,7 +14,7 @@
#include "device_base.hpp"
#include "device_gemm_xdl.hpp"
-struct Activation : public ck::tensor_operation::device::BaseGpuOperator
+struct Activation
{
    float alpha = 0.1;
@@ -191,7 +191,7 @@ int main(int argc, char* argv[])
                                      StrideA,
                                      StrideB,
                                      StrideC,
-                                     std::make_unique<Activation>(activation));
+                                     activation);
    if(!gemm.IsSupportedArgument(argument))
    {
......
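On the host side, the example now passes the functor by value rather than wrapping it in std::make_unique<Activation>. A heavily elided sketch of the new calling convention (the operator() body and the commented-out lines are illustrative assumptions; only the argument order follows the diff above):

struct Activation
{
    float alpha = 0.1;

    // assumed body for illustration; the real example defines its own operator()
    __host__ __device__ float operator()(float x) const { return x >= 0.f ? x : alpha * x; }
};

// using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl<...>;
// DeviceGemmInstance gemm;
// auto argument = gemm.MakeArgument(
//     p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, Activation{}); // functor by value
// auto invoker  = gemm.MakeInvoker();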