Commit 3c959547 authored by Jing Zhang's avatar Jing Zhang
Browse files

add multiD to gemm_c_permute

parent 85978e02
...@@ -34,7 +34,8 @@ using ADataType = F16; ...@@ -34,7 +34,8 @@ using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using DDataType = F16; using D0DataType = F16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16; using EDataType = F16;
using ALayout = Row; using ALayout = Row;
...@@ -54,7 +55,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xd ...@@ -54,7 +55,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xd
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>; < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -74,8 +75,9 @@ int main(int argc, char* argv[]) ...@@ -74,8 +75,9 @@ int main(int argc, char* argv[])
ck::index_t N = N0 * N1; ck::index_t N = N0 * N1;
ck::index_t K = 128; ck::index_t K = 128;
ck::index_t stride_A = K; ck::index_t stride_A = K;
ck::index_t stride_B = K; ck::index_t stride_B = K;
ck::index_t stride_D0 = 0;
#if 1 #if 1
// E = [M0, N0, M1, N1, M2] // E = [M0, N0, M1, N1, M2]
...@@ -84,21 +86,7 @@ int main(int argc, char* argv[]) ...@@ -84,21 +86,7 @@ int main(int argc, char* argv[])
ck::index_t stride_E_M2 = 1; ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2; ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2; ck::index_t stride_E_N1 = M2;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
#else #else
// D = [0, 0, 0, N0, N1]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
// E = [M0, M1, M2, N0, N1] // E = [M0, M1, M2, N0, N1]
ck::index_t stride_E_M0 = M1 * M2 * N0 * N1; ck::index_t stride_E_M0 = M1 * M2 * N0 * N1;
ck::index_t stride_E_M1 = M2 * N0 * N1; ck::index_t stride_E_M1 = M2 * N0 * N1;
...@@ -107,9 +95,7 @@ int main(int argc, char* argv[]) ...@@ -107,9 +95,7 @@ int main(int argc, char* argv[])
ck::index_t stride_E_N1 = 1; ck::index_t stride_E_N1 = 1;
#endif #endif
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{ const ck::tensor_operation::device::EGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1};
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1}; M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1};
if(argc == 1) if(argc == 1)
...@@ -145,17 +131,17 @@ int main(int argc, char* argv[]) ...@@ -145,17 +131,17 @@ int main(int argc, char* argv[])
}; };
auto f_host_de_tensor_descriptor = auto f_host_de_tensor_descriptor =
[](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) { [](ck::tensor_operation::device::EGridDesc_M0_M1_M2_N0_N1 e_grid_desc_) {
std::size_t m0 = de_grid_desc.M0_; std::size_t m0 = e_grid_desc_.M0_;
std::size_t m1 = de_grid_desc.M1_; std::size_t m1 = e_grid_desc_.M1_;
std::size_t m2 = de_grid_desc.M2_; std::size_t m2 = e_grid_desc_.M2_;
std::size_t n0 = de_grid_desc.N0_; std::size_t n0 = e_grid_desc_.N0_;
std::size_t n1 = de_grid_desc.N1_; std::size_t n1 = e_grid_desc_.N1_;
std::size_t stride_m0 = de_grid_desc.stride_M0_; std::size_t stride_m0 = e_grid_desc_.stride_M0_;
std::size_t stride_m1 = de_grid_desc.stride_M1_; std::size_t stride_m1 = e_grid_desc_.stride_M1_;
std::size_t stride_m2 = de_grid_desc.stride_M2_; std::size_t stride_m2 = e_grid_desc_.stride_M2_;
std::size_t stride_n0 = de_grid_desc.stride_N0_; std::size_t stride_n0 = e_grid_desc_.stride_N0_;
std::size_t stride_n1 = de_grid_desc.stride_N1_; std::size_t stride_n1 = e_grid_desc_.stride_N1_;
return HostTensorDescriptor( return HostTensorDescriptor(
std::vector<std::size_t>({m0, m1, m2, n0, n1}), std::vector<std::size_t>({m0, m1, m2, n0, n1}),
std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1})); std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1}));
...@@ -163,13 +149,13 @@ int main(int argc, char* argv[]) ...@@ -163,13 +149,13 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
Tensor<DDataType> d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc)); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, stride_D0, DLayout{}));
Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc)); Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc)); Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl; std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl; std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
...@@ -178,24 +164,23 @@ int main(int argc, char* argv[]) ...@@ -178,24 +164,23 @@ int main(int argc, char* argv[])
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) * DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
d_m0_m1_m2_n0_n1.mDesc.GetElementSpace());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) * DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) *
e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace()); e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data()); d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -206,14 +191,14 @@ int main(int argc, char* argv[]) ...@@ -206,14 +191,14 @@ int main(int argc, char* argv[])
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(), auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(), b_k_n_device_buf.GetDeviceBuffer(),
d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), {d0_m_n_device_buf.GetDeviceBuffer()},
e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(), e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
d_grid_desc, {stride_D0},
e_grid_desc, e_grid_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -228,7 +213,7 @@ int main(int argc, char* argv[]) ...@@ -228,7 +213,7 @@ int main(int argc, char* argv[])
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * N + sizeof(EDataType) * M * N; sizeof(D0DataType) * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -269,7 +254,7 @@ int main(int argc, char* argv[]) ...@@ -269,7 +254,7 @@ int main(int argc, char* argv[])
cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1), cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
ck::type_convert<EDataType>(c_m_n(m, n)), ck::type_convert<EDataType>(c_m_n(m, n)),
d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1)); d0_m_n(m, n));
} }
e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data()); e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
......
...@@ -11,34 +11,48 @@ namespace ck { ...@@ -11,34 +11,48 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
struct DEGridDesc_M0_M1_M2_N0_N1 struct EGridDesc_M0_M1_M2_N0_N1
{ {
ck::index_t M0_, M1_, M2_, N0_, N1_; ck::index_t M0_, M1_, M2_, N0_, N1_;
ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
}; };
// input : A[M, K], B[K, N], // GEMM:
// input : D[M, N], ... // input : A[M, K], B[K, N],
// output : E[M, N] // input : D0[M, N], D1[M, N], ...
// C = a_op(A) * b_op(B) // output : E[M, N]
// E = cde_op(C, D) // C = a_op(A) * b_op(B)
template <typename AElementwiseOperation, // E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... have the same layout
template <typename ALayout,
typename BLayout,
typename DLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceGemmBiasCPermute : public BaseOperator struct DeviceGemmBiasCPermute : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_d, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, std::array<ck::index_t, NumDTensor> StrideDs,
DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, EGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
...@@ -46,12 +60,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator ...@@ -46,12 +60,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -96,12 +96,12 @@ namespace device { ...@@ -96,12 +96,12 @@ namespace device {
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename DLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DDataType, typename DsDataType,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -137,19 +137,26 @@ template <typename ALayout, ...@@ -137,19 +137,26 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation, struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<ALayout,
BLayout,
DLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceGemmBiasCPermute_Xdl; using DeviceOp = DeviceGemmBiasCPermute_Xdl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t NumDTensor = I1;
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -356,19 +363,19 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -356,19 +363,19 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
} }
} }
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) static auto MakeEGridDescriptor_M_N(EGridDesc_M0_M1_M2_N0_N1 e_grid_desc)
{ {
index_t M0 = d_e_grid_desc.M0_; index_t M0 = e_grid_desc.M0_;
index_t M1 = d_e_grid_desc.M1_; index_t M1 = e_grid_desc.M1_;
index_t M2 = d_e_grid_desc.M2_; index_t M2 = e_grid_desc.M2_;
index_t N0 = d_e_grid_desc.N0_; index_t N0 = e_grid_desc.N0_;
index_t N1 = d_e_grid_desc.N1_; index_t N1 = e_grid_desc.N1_;
index_t stride_M0 = d_e_grid_desc.stride_M0_; index_t stride_M0 = e_grid_desc.stride_M0_;
index_t stride_M1 = d_e_grid_desc.stride_M1_; index_t stride_M1 = e_grid_desc.stride_M1_;
index_t stride_M2 = d_e_grid_desc.stride_M2_; index_t stride_M2 = e_grid_desc.stride_M2_;
index_t stride_N0 = d_e_grid_desc.stride_N0_; index_t stride_N0 = e_grid_desc.stride_N0_;
index_t stride_N1 = d_e_grid_desc.stride_N1_; index_t stride_N1 = e_grid_desc.stride_N1_;
const auto MRaw = M0 * M1 * M2; const auto MRaw = M0 * M1 * M2;
const auto NRaw = N0 * N1; const auto NRaw = N0 * N1;
...@@ -429,16 +436,74 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -429,16 +436,74 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
} }
} }
static auto MakeDGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD)
{
const auto d_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideD, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideD));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(d_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
d_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
d_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return d_grid_desc_mraw_nraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(EGridDesc_M0_M1_M2_N0_N1{}));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<DDataType>, DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -480,20 +545,24 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -480,20 +545,24 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using DGridDesc_M_N = decltype(MakeDGridDescriptor_M_N(1, 1, 1));
using DGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const void* p_a_grid, Argument(const void* p_a_grid,
const void* p_b_grid, const void* p_b_grid,
const void* p_d_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid, void* p_e_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, std::array<index_t, NumDTensor> StrideDs,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, EGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
...@@ -512,16 +581,6 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -512,16 +581,6 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op}
{ {
if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_)
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, b_grid_desc_bk0_n_bk1_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
...@@ -531,13 +590,18 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -531,13 +590,18 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid); static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = const auto d_grid_desc_m_n =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DeviceOp::MakeDGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
d_grid_desc_m_n);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
} }
} }
...@@ -546,17 +610,19 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -546,17 +610,19 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< StaticallyIndexedArray<DGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, NumDTensor>
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E // type from E
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
...@@ -596,9 +662,8 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -596,9 +662,8 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray< ck::StaticallyIndexedArray<DGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, NumDTensor>,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -665,29 +730,29 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -665,29 +730,29 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
static auto MakeArgument(const void* p_a, static auto MakeArgument(const void* p_a,
const void* p_b, const void* p_b,
const void* p_d, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, std::array<index_t, NumDTensor> StrideDs,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, EGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_d, p_ds,
p_e, p_e,
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
StrideA, StrideA,
StrideB, StrideB,
d_grid_desc, StrideDs,
e_grid_desc, e_grid_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -700,29 +765,29 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -700,29 +765,29 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_d, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, std::array<ck::index_t, NumDTensor> StrideDs,
DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, EGridDesc_M0_M1_M2_N0_N1 e_grid_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override CDEElementwiseOperation cde_element_op) override
{ {
return std::make_unique<Argument>(p_a, return std::make_unique<Argument>(p_a,
p_b, p_b,
p_d, p_ds,
p_e, p_e,
MRaw, MRaw,
NRaw, NRaw,
KRaw, KRaw,
StrideA, StrideA,
StrideB, StrideB,
d_grid_desc, StrideDs,
e_grid_desc, e_grid_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -210,8 +210,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -210,8 +210,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
template <typename DEGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc_M_N& e_grid_desc_m_n)
{ {
const auto M = e_grid_desc_m_n.GetLength(I0); const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1); const auto N = e_grid_desc_m_n.GetLength(I1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment