Commit 51fc99a8 authored by Anthony Chang's avatar Anthony Chang
Browse files

adds acc0 elementwise op to interface

parent 8672733f
......@@ -48,10 +48,11 @@ using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......@@ -68,6 +69,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
......@@ -287,10 +289,11 @@ int main(int argc, char* argv[])
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
......@@ -315,6 +318,7 @@ int main(int argc, char* argv[])
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
......
......@@ -22,6 +22,7 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
......@@ -46,6 +47,7 @@ struct DeviceBatchedGemmGemm : public BaseOperator
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
......@@ -62,6 +64,7 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
......@@ -74,6 +77,7 @@ using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
Acc0ElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>>;
......
......@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -45,6 +46,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
......@@ -78,6 +80,7 @@ __global__ void
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
......@@ -92,6 +95,7 @@ __global__ void
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
......@@ -119,6 +123,7 @@ template <typename ALayout,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
......@@ -173,6 +178,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
......@@ -550,6 +556,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
......@@ -624,6 +631,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
......@@ -639,6 +647,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
batch_count_(Batch),
......@@ -670,6 +679,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
index_t batch_count_;
......@@ -708,6 +718,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
......@@ -729,6 +740,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
......@@ -807,14 +819,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_b1, p_c, MRaw,
NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op,
c_element_op};
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
b1_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -839,6 +852,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
......@@ -861,6 +875,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
BatchStrideC,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
......
......@@ -23,6 +23,7 @@ template <typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -319,6 +320,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
......@@ -544,10 +546,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
FloatAB,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
decltype(acc_element_op),
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{};
n4>{acc_element_op};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
......
......@@ -1202,6 +1202,7 @@ template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
......@@ -1214,7 +1215,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic()
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
const ElementwiseOperation& element_op)
: element_op_{element_op}
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
......@@ -1277,10 +1280,18 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = src_buf[Number<src_offset>{}];
SrcData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
});
});
}
ElementwiseOperation element_op_;
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment