Commit f0201ead authored by Chao Liu

gemm+activation

parent 64350aff
@@ -20,7 +20,8 @@ template <typename GridwiseGemm,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
           typename Block2CTileMap,
-          bool HasMainKBlockLoop>
+          bool HasMainKBlockLoop,
+          typename CElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -32,7 +33,8 @@ __global__ void
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        const Block2CTileMap block_2_ctile_map)
+        const Block2CTileMap block_2_ctile_map,
+        const CElementwiseOperation c_elementwise_op)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -46,7 +48,8 @@ __global__ void
         a_grid_desc_k0_m_k1,
         b_grid_desc_k0_n_k1,
         c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        block_2_ctile_map);
+        block_2_ctile_map,
+        c_elementwise_op);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,
@@ -102,6 +105,7 @@ template <index_t BlockSize,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
           typename CGridDesc_M_N,
+          typename CElementwiseOp,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t K0PerBlock,
@@ -353,7 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-        const Block2CTileMap& block_2_ctile_map)
+        const Block2CTileMap& block_2_ctile_map,
+        const CElementwiseOp& c_elementwise_op)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -573,6 +578,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             make_naive_tensor_descriptor_packed(make_tuple(
                 Number<M0>{}, Number<N0>{}, I1, I1, Number<M2>{}, I1, Number<M4>{}, I1));

+        // elementwise Op to C
+        static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
+            [&](auto i) { c_thread_buf(i) = c_elementwise_op(c_thread_buf[i]); });
+
         // calculate origin of thread output tensor on global memory
         // blockwise GEMM c matrix starting index
         const auto c_thread_mtx_on_block =
...
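The functional change in this file is the small block added above: before each thread writes its C results back to global memory, `c_elementwise_op` is applied in-register to every element the thread owns. A minimal host-side sketch of the same pattern (the `Scale` functor and the plain loop are illustrative stand-ins, not part of this commit):

    #include <array>
    #include <cstdio>

    // Illustrative stand-in for a CElementwiseOperation: any callable
    // that maps one C value to another.
    struct Scale
    {
        float alpha = 2.f;
        float operator()(float v) const { return alpha * v; }
    };

    int main()
    {
        // Stand-in for the per-thread C register buffer (c_thread_buf).
        std::array<float, 8> c_thread_buf{1, 2, 3, 4, 5, 6, 7, 8};

        Scale c_elementwise_op{};

        // Equivalent of the static_for over GetElementSpaceSize(): transform
        // every accumulated value before it is written out to global memory.
        for(auto& v : c_thread_buf)
            v = c_elementwise_op(v);

        std::printf("%f\n", c_thread_buf[0]); // 2.0
    }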
@@ -36,6 +36,15 @@ struct BaseOperator
     virtual ~BaseOperator() {}
 };

+struct BaseGpuOperator
+{
+    BaseGpuOperator()                                  = default;
+    BaseGpuOperator(const BaseGpuOperator&)            = default;
+    BaseGpuOperator& operator=(const BaseGpuOperator&) = default;
+
+    virtual ~BaseGpuOperator() {}
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
...
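BaseGpuOperator carries no interface of its own; its virtual destructor exists so that concrete elementwise functors get a polymorphic dynamic type, which lets them travel through the virtual DeviceGemm API as a `std::unique_ptr<BaseGpuOperator>` and be recovered by the concrete device op with `dynamic_cast`. A minimal sketch of a user-defined functor built on this base (the `AddBeta` op is hypothetical, not part of the commit):

    #include <memory>

    namespace ck { namespace tensor_operation { namespace device {
    struct BaseGpuOperator
    {
        BaseGpuOperator()                                  = default;
        BaseGpuOperator(const BaseGpuOperator&)            = default;
        BaseGpuOperator& operator=(const BaseGpuOperator&) = default;
        virtual ~BaseGpuOperator() {}
    };
    }}} // namespace ck::tensor_operation::device

    // Hypothetical elementwise op: y = x + beta.
    struct AddBeta : public ck::tensor_operation::device::BaseGpuOperator
    {
        float beta = 1.f;

        template <typename T>
        T operator()(T v) const { return v + beta; }
    };

    int main()
    {
        // The functor crosses the polymorphic API type-erased...
        std::unique_ptr<ck::tensor_operation::device::BaseGpuOperator> p =
            std::make_unique<AddBeta>();

        // ...and the concrete device op recovers it, as DeviceGemmXdl does.
        auto* op = dynamic_cast<AddBeta*>(p.get());
        return (op != nullptr && (*op)(1.f) == 2.f) ? 0 : 1;
    }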
@@ -10,15 +10,17 @@ namespace device {
 struct DeviceGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                              const void* p_b,
-                                                              void* p_c,
-                                                              ck::index_t M,
-                                                              ck::index_t N,
-                                                              ck::index_t K,
-                                                              ck::index_t StrideA,
-                                                              ck::index_t StrideB,
-                                                              ck::index_t StrideC) = 0;
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::unique_ptr<BaseGpuOperator> c_element_op_ptr) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
...
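Note that the new `c_element_op_ptr` parameter is a `std::unique_ptr<BaseGpuOperator>` taken by value, so the caller transfers ownership of the functor to the device op, and the base-class type erases the concrete functor at this interface; the concrete implementation (DeviceGemmXdl below) recovers the real type with `dynamic_cast`. The example program at the end constructs the argument with `std::make_unique<Activation>(activation)`.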
@@ -22,6 +22,7 @@ template <typename ADataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
+          typename CElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -176,6 +177,7 @@ struct DeviceGemmXdl : public DeviceGemm
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
+        CElementwiseOperation,
         MPerBlock,
         NPerBlock,
         K0PerBlock,
@@ -230,7 +232,8 @@ struct DeviceGemmXdl : public DeviceGemm
                  index_t StrideB,
                  index_t StrideC,
                  index_t M01,
-                 index_t N01)
+                 index_t N01,
+                 CElementwiseOperation c_elementwise_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
@@ -240,7 +243,8 @@ struct DeviceGemmXdl : public DeviceGemm
               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
               block_2_ctile_map_{},
               M01_{M01},
-              N01_{N01}
+              N01_{N01},
+              c_elementwise_op_{c_elementwise_op}
         {
             a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
             b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
@@ -267,6 +271,7 @@ struct DeviceGemmXdl : public DeviceGemm
         Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
+        CElementwiseOperation c_elementwise_op_;
     };

     // Invoker
@@ -317,7 +322,8 @@ struct DeviceGemmXdl : public DeviceGemm
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                     remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    true>;
+                    true,
+                    CElementwiseOperation>;

                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,
@@ -330,7 +336,8 @@ struct DeviceGemmXdl : public DeviceGemm
                                                   arg.a_grid_desc_k0_m_k1_,
                                                   arg.b_grid_desc_k0_n_k1_,
                                                   arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_);
+                                                  arg.block_2_ctile_map_,
+                                                  arg.c_elementwise_op_);
             }
             else
             {
@@ -342,7 +349,8 @@ struct DeviceGemmXdl : public DeviceGemm
                     remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
                     remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                     remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
-                    false>;
+                    false,
+                    CElementwiseOperation>;

                 ave_time = launch_and_time_kernel(kernel,
                                                   nrepeat,
@@ -355,7 +363,8 @@ struct DeviceGemmXdl : public DeviceGemm
                                                   arg.a_grid_desc_k0_m_k1_,
                                                   arg.b_grid_desc_k0_n_k1_,
                                                   arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.block_2_ctile_map_);
+                                                  arg.block_2_ctile_map_,
+                                                  arg.c_elementwise_op_);
             }

             return ave_time;
@@ -397,23 +406,37 @@ struct DeviceGemmXdl : public DeviceGemm
                              index_t K,
                              index_t StrideA,
                              index_t StrideB,
-                             index_t StrideC)
+                             index_t StrideC,
+                             std::unique_ptr<BaseGpuOperator> c_op_ptr)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        1,
+                        1,
+                        *dynamic_cast<CElementwiseOperation*>(c_op_ptr.get())};
     }

     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      index_t M,
-                                                      index_t N,
-                                                      index_t K,
-                                                      index_t StrideA,
-                                                      index_t StrideB,
-                                                      index_t StrideC) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t StrideA,
+                        index_t StrideB,
+                        index_t StrideC,
+                        std::unique_ptr<BaseGpuOperator> c_op_ptr) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
@@ -425,7 +448,8 @@ struct DeviceGemmXdl : public DeviceGemm
                                           StrideB,
                                           StrideC,
                                           1,
-                                          1);
+                                          1,
+                                          *dynamic_cast<CElementwiseOperation*>(c_op_ptr.get()));
     }

     // polymorphic
...
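One caveat with the two `*dynamic_cast<CElementwiseOperation*>(c_op_ptr.get())` expressions above: `dynamic_cast` on a pointer returns `nullptr` on a type mismatch, and dereferencing that is undefined behavior. A defensive variant of the cast (a sketch, not in the commit; `checked_op_cast` is a hypothetical helper):

    #include <memory>
    #include <stdexcept>

    struct BaseGpuOperator { virtual ~BaseGpuOperator() {} };
    struct Activation : BaseGpuOperator { float alpha = 0.1f; };

    // Cast the type-erased op back to its concrete functor type,
    // failing loudly instead of dereferencing a null pointer.
    template <typename Op>
    const Op& checked_op_cast(const std::unique_ptr<BaseGpuOperator>& p)
    {
        auto* op = dynamic_cast<Op*>(p.get());
        if(op == nullptr)
            throw std::runtime_error("elementwise op has the wrong dynamic type");
        return *op;
    }

    int main()
    {
        std::unique_ptr<BaseGpuOperator> p = std::make_unique<Activation>();
        const Activation& act = checked_op_cast<Activation>(p); // ok
        return act.alpha > 0.f ? 0 : 1;
    }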
@@ -14,21 +14,36 @@
 #include "device_base.hpp"
 #include "device_gemm_xdl.hpp"

+struct Activation : public ck::tensor_operation::device::BaseGpuOperator
+{
+    float alpha = 0.1;
+
+    // ReLU
+    template <typename T>
+    __host__ __device__ T operator()(T v) const
+    {
+        T tmp = alpha * v;
+        return tmp > 0 ? tmp : 0;
+    }
+};
+
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
           typename ALayout,
           typename BLayout,
-          typename CLayout>
+          typename CLayout,
+          typename CElementwiseOperation>
 struct DeviceGemmInstance;

-template <>
+template <typename CElementwiseOperation>
 struct DeviceGemmInstance<ck::half_t,
                           ck::half_t,
                           ck::half_t,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor>
+                          ck::tensor_layout::gemm::RowMajor,
+                          CElementwiseOperation>
 {
     using F16 = ck::half_t;
     using F32 = float;
@@ -42,21 +57,22 @@ struct DeviceGemmInstance<ck::half_t,
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
     // clang-format on
 };
-template <>
+template <typename CElementwiseOperation>
 struct DeviceGemmInstance<float,
                           float,
                           float,
                           ck::tensor_layout::gemm::RowMajor,
                           ck::tensor_layout::gemm::ColumnMajor,
-                          ck::tensor_layout::gemm::RowMajor>
+                          ck::tensor_layout::gemm::RowMajor,
+                          CElementwiseOperation>
 {
     using F16 = ck::half_t;
     using F32 = float;
@@ -70,11 +86,11 @@ struct DeviceGemmInstance<float,
     // Compilation parameters for NT problem
     // clang-format off
     using type =
-        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
-        //########################################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
-        //########################################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
-        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
+        //########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
+        //########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
+        //########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
+        //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
     // clang-format on
 };
@@ -155,9 +171,15 @@ int main(int argc, char* argv[])
     c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());

     // do GEMM
-    auto gemm =
-        typename DeviceGemmInstance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>::
-            type{};
+    auto gemm = typename DeviceGemmInstance<ADataType,
+                                            BDataType,
+                                            CDataType,
+                                            ALayout,
+                                            BLayout,
+                                            CLayout,
+                                            Activation>::type{};
+
+    auto activation = Activation{};

     auto invoker = gemm.MakeInvoker();

     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
@@ -168,7 +190,8 @@ int main(int argc, char* argv[])
                                       K,
                                       StrideA,
                                       StrideB,
-                                      StrideC);
+                                      StrideC,
+                                      std::make_unique<Activation>(activation));

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -194,7 +217,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result);
+        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, activation);

         check_error(c_m_n_host_result, c_m_n_device_result);
     }
...
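A note on the Activation functor used for both paths: despite its `// ReLU` comment, it computes `max(alpha * v, 0)` with `alpha = 0.1`, i.e. a ReLU of the scaled input. For example, `v = 2` gives `0.1 * 2 = 0.2`, while any negative `v` gives `0`. Because the host reference below applies the very same functor, the verification still compares like with like; only the comment is imprecise.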
 #pragma once
 #include "host_tensor.hpp"

-template <typename AType, typename BType, typename CType>
+template <typename AType, typename BType, typename CType, typename CElementwiseOperation>
 void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                         const Tensor<BType>& b_k_n,
-                        Tensor<CType>& c_m_n)
+                        Tensor<CType>& c_m_n,
+                        const CElementwiseOperation& c_element_op)
 {
     auto f_mk_kn_mn = [&](auto m, auto n) {
         const int K = a_m_k.mDesc.GetLengths()[1];
@@ -16,7 +17,7 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
             v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
         }

-        c_m_n(m, n) = v;
+        c_m_n(m, n) = c_element_op(v);
     };

     make_ParallelTensorFunctor(f_mk_kn_mn,
...
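Taken together, the host reference now matches the device path: accumulate the inner product in double, then apply the elementwise op to the accumulated value before storing it. A self-contained sketch of that reference computation (plain vectors stand in for the Tensor class and the parallel functor):

    #include <cstdio>
    #include <vector>

    struct Activation
    {
        float alpha = 0.1f;

        template <typename T>
        T operator()(T v) const
        {
            T tmp = alpha * v;
            return tmp > 0 ? tmp : 0;
        }
    };

    int main()
    {
        const int M = 2, N = 2, K = 3;
        std::vector<float> a = {1, 2, 3, 4, 5, 6}; // M x K, row-major
        std::vector<float> b = {1, 0, 0, 1, 1, 1}; // K x N, row-major
        std::vector<float> c(M * N);

        Activation c_element_op{};

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                double v = 0;
                for(int k = 0; k < K; ++k)
                    v += static_cast<double>(a[m * K + k]) * static_cast<double>(b[k * N + n]);

                // as in host_gemm_mk_kn_mn: the elementwise op is fused
                // into the reference right after accumulation
                c[m * N + n] = c_element_op(v);
            }

        std::printf("c(0,0) = %f\n", c[0]); // 0.1 * (1*1 + 2*0 + 3*1) = 0.4
    }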