Commit 6ef4e211 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into contraction

parents b0a2afb9 9e4429f9
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
...@@ -113,7 +113,7 @@ __global__ void ...@@ -113,7 +113,7 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = compute_ptr_offset_of_batch; ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif
} }
template <typename ADataType, template <typename ADataType,
...@@ -151,8 +151,15 @@ template <typename ADataType, ...@@ -151,8 +151,15 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceBatchedGemmXdl struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -334,26 +341,26 @@ struct DeviceBatchedGemmXdl ...@@ -334,26 +341,26 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t BatchCount)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
BatchCount_(BatchCount), Batch_(Batch),
a_grid_desc_k0_m_k1_{ a_grid_desc_k0_m_k1_{
DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
b_grid_desc_k0_n_k1_{ b_grid_desc_k0_n_k1_{
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
compute_ptr_offset_of_batch_{ compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC},
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
block_2_ctile_map_{ block_2_ctile_map_{
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
M01_{M01}, M01_{M01},
...@@ -376,7 +383,7 @@ struct DeviceBatchedGemmXdl ...@@ -376,7 +383,7 @@ struct DeviceBatchedGemmXdl
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t BatchCount_; index_t Batch_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...@@ -420,7 +427,7 @@ struct DeviceBatchedGemmXdl ...@@ -420,7 +427,7 @@ struct DeviceBatchedGemmXdl
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
const auto K = const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
...@@ -451,7 +458,7 @@ struct DeviceBatchedGemmXdl ...@@ -451,7 +458,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -485,7 +492,7 @@ struct DeviceBatchedGemmXdl ...@@ -485,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.BatchCount_, arg.Batch_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -536,10 +543,13 @@ struct DeviceBatchedGemmXdl ...@@ -536,10 +543,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t BatchCount)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -550,12 +560,15 @@ struct DeviceBatchedGemmXdl ...@@ -550,12 +560,15 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op};
BatchCount};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -570,10 +583,13 @@ struct DeviceBatchedGemmXdl ...@@ -570,10 +583,13 @@ struct DeviceBatchedGemmXdl
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t BatchCount) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -584,12 +600,15 @@ struct DeviceBatchedGemmXdl ...@@ -584,12 +600,15 @@ struct DeviceBatchedGemmXdl
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
BatchStrideA,
BatchStrideB,
BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op);
BatchCount);
} }
// polymorphic // polymorphic
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp" #include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp" #include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace ck { namespace ck {
...@@ -25,7 +26,7 @@ template <typename ADataType, ...@@ -25,7 +26,7 @@ template <typename ADataType,
index_t AScalarPerVector, index_t AScalarPerVector,
index_t BScalarPerVector, index_t BScalarPerVector,
index_t CScalarPerVector> index_t CScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -198,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return true; return true;
}; };
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(std::array<const void*, 2> p_inputs,
void* p_c, std::array<void*, 1> p_outputs,
std::vector<index_t> lengths, std::vector<index_t> lengths,
std::vector<index_t> a_strides, std::vector<std::vector<index_t>> input_strides,
std::vector<index_t> b_strides, std::vector<std::vector<index_t>> output_strides,
std::vector<index_t> c_strides, ElementwiseFunctor functor) override
ElementwiseFunctor functor)
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_inputs[1]),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_outputs[0]),
lengths, lengths,
a_strides, input_strides[0],
b_strides, input_strides[1],
c_strides, output_strides[0],
functor); functor);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); } std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
// polymorphic
std::string GetTypeString() const override std::string GetTypeString() const override
{ {
auto str = std::stringstream(); auto str = std::stringstream();
...@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator ...@@ -226,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off // clang-format off
str << "DeviceBinaryElementwise" str << "DeviceBinaryElementwise"
<< "<" << "<"
<< "NDim = " << NDim
<< "MPerThread = " << MPerThread << "MPerThread = " << MPerThread
<< "AScalarPerVector = " << AScalarPerVector
<< "BScalarPerVector = " << BScalarPerVector
<< "CScalarPerVector = " << CScalarPerVector
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -871,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -871,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -711,7 +711,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -711,7 +711,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -1033,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -1033,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
struct DeviceElementwise : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
std::array<void*, NumOutputTensor> p_outputs,
std::vector<index_t> lengths,
std::vector<std::vector<index_t>> input_strides,
std::vector<std::vector<index_t>> output_strides,
ElementwiseFunctor functor) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <ck::index_t NumInputTensor,
ck::index_t NumOutputTensor,
index_t NDim,
typename ElementwiseFunctor>
using DeviceElementwisePtr =
std::unique_ptr<DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -18,33 +18,52 @@ struct GemmShape ...@@ -18,33 +18,52 @@ struct GemmShape
ck::index_t StrideA, StrideB, StrideC; ck::index_t StrideA, StrideB, StrideC;
}; };
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator struct DeviceGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
const void* p_b, MakeArgumentPointer(const void* p_a,
void* p_c, const void* p_b,
ck::index_t M, void* p_c,
ck::index_t N, ck::index_t M,
ck::index_t K, ck::index_t N,
ck::index_t StrideA, ck::index_t K,
ck::index_t StrideB, ck::index_t StrideA,
ck::index_t StrideC, ck::index_t StrideB,
AElementwiseOperation a_element_op, ck::index_t StrideC,
BElementwiseOperation b_element_op, AElementwiseOperation a_element_op,
CElementwiseOperation c_element_op, BElementwiseOperation b_element_op,
ck::index_t KBatch = 1) = 0; CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr< using DeviceGemmPtr = std::unique_ptr<DeviceGemm<ALayout,
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
......
...@@ -3,33 +3,45 @@ ...@@ -3,33 +3,45 @@
#pragma once #pragma once
#include <iostream> #include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
struct DEGridDesc_M0_M1_M2_N0_N1
{
ck::index_t M0_, M1_, M2_, N0_, N1_;
ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
};
// input : A[M, K], B[K, N],
// input : D[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D)
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceGemmBias : public BaseOperator struct DeviceGemmBiasCPermute : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias, const void* p_d,
void* p_c, void* p_e,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
...@@ -37,8 +49,8 @@ struct DeviceGemmBias : public BaseOperator ...@@ -37,8 +49,8 @@ struct DeviceGemmBias : public BaseOperator
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr< using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -64,8 +64,16 @@ template < ...@@ -64,8 +64,16 @@ template <
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>, is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false> bool> = false>
struct DeviceGemmDl struct DeviceGemmDl : public DeviceGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -534,8 +542,7 @@ struct DeviceGemmDl ...@@ -534,8 +542,7 @@ struct DeviceGemmDl
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
......
...@@ -16,12 +16,20 @@ namespace device { ...@@ -16,12 +16,20 @@ namespace device {
// output : E[M, N] // output : E[M, N]
// C = a_op(A) * b_op(B) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
template <ck::index_t NumDTensor, template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceGemmMultipleD : public BaseOperator struct DeviceGemmMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
...@@ -41,14 +49,26 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -41,14 +49,26 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <ck::index_t NumDTensor, template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CDEElementwiseOperation>
using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<NumDTensor, using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>>; CDEElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -96,7 +96,7 @@ namespace device { ...@@ -96,7 +96,7 @@ namespace device {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename DELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -137,7 +137,13 @@ template <typename ALayout, ...@@ -137,7 +137,13 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType::Size(), struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
...@@ -360,12 +366,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -360,12 +366,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CDELayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1)); make_tuple(StrideE, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CDELayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE)); make_tuple(I1, StrideE));
...@@ -740,7 +746,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -740,7 +746,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -2,98 +2,44 @@ ...@@ -2,98 +2,44 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include "device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, // FIXME: DeviceGemmReduce type need to well define the problem
typename BElementwiseOperation, template <ck::index_t NumDTensor, ck::index_t NumReduce>
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator struct DeviceGemmReduce : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
void* p_dxs, std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t StrideC1, std::array<ck::index_t, NumDTensor> StrideDs,
AElementwiseOperation a_element_op, std::array<void*, 3> gemm_element_ops,
BElementwiseOperation b_element_op, std::array<void*, NumDTensor> d_element_ops,
CElementwiseOperation c_element_op, std::array<void*, NumReduce> reduce_in_element_ops,
C1ElementwiseOperation c1_element_op, std::array<void*, NumReduce> reduce_out_element_ops,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <ck::index_t NumDTensor, ck::index_t NumReduce>
typename BElementwiseOperation, using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce>>;
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP #pragma once
#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#include <iostream> #include <iostream>
#include <vector>
#include "device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemmBiasActivationAdd : public BaseOperator struct DeviceGemmSplitK : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
const void* p_c0,
const void* p_c1,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0; ck::index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmBiasActivationAddPtr = using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
std::unique_ptr<DeviceGemmBiasActivationAdd<AElementwiseOperation, BLayout,
BElementwiseOperation, CLayout,
CElementwiseOperation>>; ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -57,8 +57,15 @@ template <typename ADataType, ...@@ -57,8 +57,15 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1> ck::index_t NumPrefetch = 1>
struct DeviceGemmXdl struct DeviceGemmXdl : public DeviceGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -487,8 +494,7 @@ struct DeviceGemmXdl ...@@ -487,8 +494,7 @@ struct DeviceGemmXdl
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment