Commit 51fc99a8 authored by Anthony Chang's avatar Anthony Chang
Browse files

adds acc0 elementwise op to interface

parent 8672733f
...@@ -48,10 +48,11 @@ using B0Layout = Col; ...@@ -48,10 +48,11 @@ using B0Layout = Col;
using B1Layout = Row; using B1Layout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using B0ElementOp = PassThrough; using B0ElementOp = PassThrough;
using B1ElementOp = PassThrough; using Acc0ElementOp = PassThrough;
using CElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
...@@ -68,6 +69,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X ...@@ -68,6 +69,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X
CShuffleDataType, CShuffleDataType,
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
Acc0ElementOp,
B1ElementOp, B1ElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
...@@ -287,10 +289,11 @@ int main(int argc, char* argv[]) ...@@ -287,10 +289,11 @@ int main(int argc, char* argv[])
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
auto b1_element_op = B1ElementOp{}; auto acc0_element_op = Acc0ElementOp{};
auto c_element_op = CElementOp{}; auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
...@@ -315,6 +318,7 @@ int main(int argc, char* argv[]) ...@@ -315,6 +318,7 @@ int main(int argc, char* argv[])
BatchStrideC, BatchStrideC,
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op);
......
...@@ -22,6 +22,7 @@ template <typename ALayout, ...@@ -22,6 +22,7 @@ template <typename ALayout,
typename CDataType, typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator struct DeviceBatchedGemmGemm : public BaseOperator
...@@ -46,6 +47,7 @@ struct DeviceBatchedGemmGemm : public BaseOperator ...@@ -46,6 +47,7 @@ struct DeviceBatchedGemmGemm : public BaseOperator
ck::index_t BatchStrideC, ck::index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op) = 0;
...@@ -62,6 +64,7 @@ template <typename ALayout, ...@@ -62,6 +64,7 @@ template <typename ALayout,
typename CDataType, typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename B0ElementwiseOperation, typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout, using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
...@@ -74,6 +77,7 @@ using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout, ...@@ -74,6 +77,7 @@ using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
B0ElementwiseOperation, B0ElementwiseOperation,
Acc0ElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation>>; CElementwiseOperation>>;
......
...@@ -25,6 +25,7 @@ template <typename GridwiseGemm, ...@@ -25,6 +25,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -45,6 +46,7 @@ __global__ void ...@@ -45,6 +46,7 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op, const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
...@@ -78,6 +80,7 @@ __global__ void ...@@ -78,6 +80,7 @@ __global__ void
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
...@@ -92,6 +95,7 @@ __global__ void ...@@ -92,6 +95,7 @@ __global__ void
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op; ignore = b1_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
...@@ -119,6 +123,7 @@ template <typename ALayout, ...@@ -119,6 +123,7 @@ template <typename ALayout,
typename CShuffleDataType, typename CShuffleDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
...@@ -173,6 +178,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -173,6 +178,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
...@@ -550,6 +556,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -550,6 +556,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
...@@ -624,6 +631,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -624,6 +631,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC, index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
...@@ -639,6 +647,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -639,6 +647,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), batch_count_(Batch),
...@@ -670,6 +679,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -670,6 +679,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
index_t batch_count_; index_t batch_count_;
...@@ -708,6 +718,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -708,6 +718,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
...@@ -729,6 +740,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -729,6 +740,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
arg.p_c_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_, arg.b1_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
...@@ -807,14 +819,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -807,14 +819,15 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC, index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{p_a, p_b, p_b1, p_c, MRaw, return Argument{p_a, p_b, p_b1, p_c, MRaw,
NRaw, KRaw, Gemm1NRaw, Batch, StrideA, NRaw, KRaw, Gemm1NRaw, Batch, StrideA,
StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB,
BatchStrideB1, BatchStrideC, a_element_op, b_element_op, b1_element_op, BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op,
c_element_op}; b1_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -839,6 +852,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -839,6 +852,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
index_t BatchStrideC, index_t BatchStrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
{ {
...@@ -861,6 +875,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -861,6 +875,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
BatchStrideC, BatchStrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
acc_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op);
} }
......
...@@ -23,6 +23,7 @@ template <typename FloatAB, ...@@ -23,6 +23,7 @@ template <typename FloatAB,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
...@@ -319,6 +320,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -319,6 +320,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op, const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
...@@ -544,10 +546,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -544,10 +546,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
FloatAB, FloatAB,
decltype(acc_thread_desc_k0_m_k1), decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1), decltype(a1_thread_desc_k0_m_k1),
decltype(acc_element_op),
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>, Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
2, 2,
n4>{}; n4>{acc_element_op};
// B1 matrix blockwise copy // B1 matrix blockwise copy
auto b1_blockwise_copy = auto b1_blockwise_copy =
......
...@@ -1202,6 +1202,7 @@ template <typename SrcData, ...@@ -1202,6 +1202,7 @@ template <typename SrcData,
typename DstData, typename DstData,
typename SrcDesc, typename SrcDesc,
typename DstDesc, typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths, typename SliceLengths,
typename DimAccessOrder, typename DimAccessOrder,
index_t DstVectorDim, index_t DstVectorDim,
...@@ -1214,7 +1215,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1214,7 +1215,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic() __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
const ElementwiseOperation& element_op)
: element_op_{element_op}
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time"); "wrong! Desc need to known at compile-time");
...@@ -1277,10 +1280,18 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1277,10 +1280,18 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = src_buf[Number<src_offset>{}]; SrcData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
}); });
}); });
} }
ElementwiseOperation element_op_;
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment