Commit 2066a3d4 authored by Chao Liu

add pointwise operations to A/B matrices

parent 496e2ec6
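
In short, the GEMM entry points now take three element-wise operations, one per tensor, instead of a single C operation. A minimal sketch of the changed call pattern, using the Equal and Relu functors introduced in the example file below (the surrounding variables are illustrative, not from this commit):

    // Before: MakeArgument(..., StrideA, StrideB, StrideC, c_element_op)
    // After:  A and B each get their own pointwise operation, applied while
    //         tiles are read from global memory; the C operation is applied
    //         when results are written out.
    auto argument = gemm.MakeArgument(p_a, p_b, p_c,
                                      M, N, K,
                                      StrideA, StrideB, StrideC,
                                      Equal{},  // a_element_op
                                      Equal{},  // b_element_op
                                      Relu{});  // c_element_op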
......@@ -14,6 +14,7 @@ namespace ck {
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename SrcElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
......@@ -39,12 +40,17 @@ struct BlockwiseTensorSliceTransfer_v4
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
__device__ constexpr BlockwiseTensorSliceTransfer_v4(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const SrcElementwiseOperation& src_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
src_element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
......@@ -147,6 +153,7 @@ struct BlockwiseTensorSliceTransfer_v4
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<ThreadSliceLengths,
SrcElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
......
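
The operation forwarded here only needs to be copy-constructible and callable element by element on SrcData; a hedged sketch of the minimal functor shape the transfer assumes (Scale is illustrative, not part of this commit):

    struct Scale
    {
        float factor = 2.0f;

        // called once per element as data is read from the source tensor
        __host__ __device__ constexpr float operator()(float v) const
        {
            return factor * v;
        }
    };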
......@@ -19,9 +19,11 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap,
bool HasMainKBlockLoop,
typename CElementwiseOperation>
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -33,8 +35,10 @@ __global__ void
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const Block2CTileMap block_2_ctile_map,
const CElementwiseOperation c_element_op)
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -48,8 +52,10 @@ __global__ void
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
block_2_ctile_map,
c_element_op);
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
......@@ -58,8 +64,10 @@ template <typename GridwiseGemm,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
typename Block2CTileMap,
typename CElementwiseOperation>
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename Block2CTileMap>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -70,8 +78,10 @@ __global__ void
const void CONSTANT* p_a_grid_desc_k0_m_k1,
const void CONSTANT* p_b_grid_desc_k0_n_k1,
const void CONSTANT* p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const void CONSTANT* p_block_2_ctile_map,
const void CONSTANT* p_c_element_op)
const void CONSTANT* p_a_element_op,
const void CONSTANT* p_b_element_op,
const void CONSTANT* p_c_element_op,
const void CONSTANT* p_block_2_ctile_map)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -85,6 +95,10 @@ __global__ void
cast_pointer_to_generic_address_space(p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2));
const auto block_2_ctile_map = *reinterpret_cast<const Block2CTileMap*>(
cast_pointer_to_generic_address_space(p_block_2_ctile_map));
const auto a_element_op = *reinterpret_cast<const AElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_a_element_op));
const auto b_element_op = *reinterpret_cast<const BElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_b_element_op));
const auto c_element_op = *reinterpret_cast<const CElementwiseOperation*>(
cast_pointer_to_generic_address_space(p_c_element_op));
......@@ -97,8 +111,10 @@ __global__ void
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
block_2_ctile_map,
c_element_op);
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
}
#endif
......@@ -110,6 +126,8 @@ template <index_t BlockSize,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
......@@ -362,8 +380,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const Block2CTileMap& block_2_ctile_map,
const CElementwiseOperation& c_element_op)
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
......@@ -421,6 +441,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
AElementwiseOperation,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
......@@ -442,11 +463,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
true>(a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0));
make_multi_index(0, 0, 0),
a_element_op);
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
BElementwiseOperation,
InMemoryDataOperationEnum_t::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
......@@ -468,7 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
true>(b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_k0_n_k1,
make_multi_index(0, 0, 0));
make_multi_index(0, 0, 0),
b_element_op);
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
......
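
Taken together with the C-side operation, the computation the kernel now performs is, in effect,

    c(m, n) = c_op( sum_k a_op(a(k, m)) * b_op(b(k, n)) )

which matches the host reference updated at the bottom of this commit.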
......@@ -50,7 +50,7 @@ template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOp,
typename SrcElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
......@@ -69,11 +69,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOp element_op)
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const SrcElementwiseOperation src_element_op)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
element_op_{element_op}
src_element_op_{src_element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
......@@ -200,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
// apply element-wise operation and type convert
dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>(element_op_(src_buf[Number<src_offset>{}]));
type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
});
const bool is_dst_valid =
......@@ -377,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
private:
DstCoord dst_coord_;
ElementwiseOp element_op_;
SrcElementwiseOperation src_element_op_;
}; // struct ThreadwiseTensorSliceTransfer_v1r3
// Assume:
......
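
Stripped of the CK buffer machinery, the per-element step v1r3 performs is apply-then-convert; a plain C++ sketch (the helper name is hypothetical):

    template <typename DstData, typename SrcData, typename Op>
    DstData apply_and_convert(const SrcData& s, const Op& op)
    {
        // mirrors: dst = type_convert<DstData>(src_element_op_(src))
        return static_cast<DstData>(op(s));
    }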
......@@ -46,6 +46,7 @@ struct lambda_scalar_per_access_for_src_and_dst
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
typename SrcElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData,
typename DstData,
......@@ -76,12 +77,15 @@ struct ThreadwiseTensorSliceTransfer_v3r2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const SrcElementwiseOperation& src_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op)
{
}
......@@ -191,12 +195,22 @@ struct ThreadwiseTensorSliceTransfer_v3r2
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_t = typename vector_type_maker_t<SrcData, SrcScalarPerVector>::type;
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf to src_thread_scratch_
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
src_vector_container.template AsType<SrcData>()(i) =
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_.template SetAsType<src_vector_t>(
src_data_idx_seq,
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid));
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
{
......@@ -796,6 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
SrcCoord src_coord_;
DstCoord dst_coord_;
SrcElementwiseOperation src_element_op_;
};
} // namespace ck
......
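
The v3r2 path works on whole vectors: it loads SrcScalarPerVector elements at once, applies the operation lane by lane, and only then stores the vector to the thread scratch. A simplified stand-alone model, using plain arrays in place of CK's vector_type and StaticBuffer:

    template <typename T, int N, typename Op>
    void load_transform_store(const T (&src_vec)[N], T (&scratch)[N], const Op& op)
    {
        // mirrors the static_for over SrcScalarPerVector above
        for(int i = 0; i < N; ++i)
        {
            scratch[i] = op(src_vec[i]);
        }
    }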
......@@ -8,7 +8,9 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename CElementwiseOperation>
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......@@ -21,13 +23,18 @@ struct DeviceGemm : public BaseOperator
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr<DeviceGemm<CElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr<
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
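
Client code that held a DeviceGemmPtr must now spell out all three operation types; a hedged sketch against the widened interface, reusing the functors from the example file below:

    using GemmPtr =
        ck::tensor_operation::device::DeviceGemmPtr<Equal, Equal, Relu>;

    // given GemmPtr gemm_ptr:
    //   auto arg = gemm_ptr->MakeArgumentPointer(p_a, p_b, p_c,
    //                                            M, N, K,
    //                                            StrideA, StrideB, StrideC,
    //                                            Equal{}, Equal{}, Relu{});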
......@@ -22,6 +22,8 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t BlockSize,
ck::index_t MPerBlock,
......@@ -50,7 +52,8 @@ template <typename ADataType,
ck::index_t CThreadTransferDstScalarPerVector,
bool ABlockLdsAddExtraM,
bool BBlockLdsAddExtraN>
struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
struct DeviceGemmXdl
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -177,6 +180,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
......@@ -233,6 +238,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
index_t StrideC,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
......@@ -244,6 +251,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
......@@ -271,6 +280,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
......@@ -321,9 +332,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
true,
CElementwiseOperation>;
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -336,8 +349,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.block_2_ctile_map_,
arg.c_element_op_);
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
else
{
......@@ -348,9 +363,11 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
false,
CElementwiseOperation>;
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -363,8 +380,10 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.block_2_ctile_map_,
arg.c_element_op_);
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
}
return ave_time;
......@@ -407,9 +426,24 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1, c_element_op};
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -424,6 +458,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......@@ -437,6 +473,8 @@ struct DeviceGemmXdl : public DeviceGemm<CElementwiseOperation>
StrideC,
1,
1,
a_element_op,
b_element_op,
c_element_op);
}
......
......@@ -14,13 +14,22 @@
#include "device_base.hpp"
#include "device_gemm_xdl.hpp"
struct Activation
struct Equal
{
template <typename T>
__host__ __device__ constexpr T operator()(T v) const
{
return v;
}
};
struct Relu
{
float alpha = 0.1;
// scaled ReLU: returns max(alpha * v, 0)
template <typename T>
__host__ __device__ T operator()(T v) const
__host__ __device__ constexpr T operator()(T v) const
{
T tmp = alpha * v;
return tmp > 0 ? tmp : 0;
......@@ -33,16 +42,22 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmInstance;
template <typename CElementwiseOperation>
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmInstance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using F16 = ck::half_t;
......@@ -54,24 +69,32 @@ struct DeviceGemmInstance<ck::half_t,
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using AOp = AElementwiseOperation;
using BOp = BElementwiseOperation;
using COp = CElementwiseOperation;
// Compilation parameters for NT problem
// clang-format off
using type =
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
// clang-format on
};
template <typename CElementwiseOperation>
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmInstance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using F16 = ck::half_t;
......@@ -83,14 +106,18 @@ struct DeviceGemmInstance<float,
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using AOp = AElementwiseOperation;
using BOp = BElementwiseOperation;
using COp = CElementwiseOperation;
// Compilation parameters for NT problem
// clang-format off
using type =
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| CElementwiseOperation| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//########################################| Type| Type| Type| Type| | | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//########################################| | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, CElementwiseOperation, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
// clang-format on
};
......@@ -177,9 +204,9 @@ int main(int argc, char* argv[])
ALayout,
BLayout,
CLayout,
Activation>::type{};
auto activation = Activation{};
Equal,
Equal,
Relu>::type{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
......@@ -191,7 +218,9 @@ int main(int argc, char* argv[])
StrideA,
StrideB,
StrideC,
activation);
Equal{},
Equal{},
Relu{});
if(!gemm.IsSupportedArgument(argument))
{
......@@ -217,7 +246,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, activation);
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, Equal{}, Equal{}, Relu{});
check_error(c_m_n_host_result, c_m_n_device_result);
}
......
#pragma once
#include "host_tensor.hpp"
template <typename AType, typename BType, typename CType, typename CElementwiseOperation>
template <typename AType,
typename BType,
typename CType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
const Tensor<BType>& b_k_n,
Tensor<CType>& c_m_n,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
......@@ -14,7 +21,8 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
for(int k = 0; k < K; ++k)
{
v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
static_cast<const double>(b_element_op(b_k_n(k, n)));
}
c_m_n(m, n) = c_element_op(v);
......
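
The reference only requires the operations to be callable on scalars, so plain lambdas would work as well; a hypothetical usage sketch:

    // host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n,
    //                    [](auto v) { return v; },              // a_element_op
    //                    [](auto v) { return v; },              // b_element_op
    //                    [](auto v) { return v > 0 ? v : 0; }); // c_element_op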