Commit bd0f0686 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents e9b1000f 63914743
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
std::vector<index_t> a_ms_ks_lengths,
std::vector<index_t> a_ms_ks_strides,
std::vector<index_t> b_ns_ks_lengths,
std::vector<index_t> b_ns_ks_strides,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths,
std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides,
std::vector<index_t> e_ms_ns_lengths,
std::vector<index_t> e_ms_ns_strides,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -871,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -871,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -711,7 +711,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -711,7 +711,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -996,8 +996,30 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -996,8 +996,30 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
0, 0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType))); sizeof(CDataType)));
float elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return launch_and_time_kernel(stream_config, hipGetErrorString(hipMemset(
arg.p_c_grid_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(CDataType)));
launch_and_time_kernel(StreamConfig{nullptr, false},
kernel, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
...@@ -1012,6 +1034,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -1012,6 +1034,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
return elapsed_time;
}; };
// run kernel for bf16 with splitk // run kernel for bf16 with splitk
...@@ -1022,7 +1046,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -1022,7 +1046,8 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType))); sizeof(AccDataType)));
return launch_and_time_kernel(stream_config, float elapsed_time =
launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
...@@ -1037,6 +1062,30 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -1037,6 +1062,30 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
hipGetErrorString(hipMemset(
arg.p_workspace_,
0,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
sizeof(AccDataType)));
launch_and_time_kernel(StreamConfig{nullptr, false},
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
static_cast<AccDataType*>(arg.p_workspace_),
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return elapsed_time;
}; };
// kernel for type conversion // kernel for type conversion
...@@ -1210,6 +1259,20 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -1210,6 +1259,20 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NumDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
}
// vector load A/B matrix from global memory // vector load A/B matrix from global memory
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
...@@ -1334,6 +1397,12 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -1334,6 +1397,12 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock
<< ">"; << ">";
if constexpr(ConvBackwardWeightSpecialization ==
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0){
str << " Filter1x1Stride1Pad0";
}
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -1033,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -1033,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< getConvFwdSpecializationStr(ConvForwardSpecialization) << getConvForwardSpecializationString(ConvForwardSpecialization)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -2,21 +2,29 @@ ...@@ -2,21 +2,29 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator struct DeviceGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
ck::index_t M, ck::index_t M,
...@@ -27,17 +35,29 @@ struct DeviceGemm : public BaseOperator ...@@ -27,17 +35,29 @@ struct DeviceGemm : public BaseOperator
ck::index_t StrideC, ck::index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) = 0;
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmPtr = std::unique_ptr< using DeviceGemmPtr = std::unique_ptr<DeviceGemm<ALayout,
DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmBiasActivationAdd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasActivationAddPtr =
std::unique_ptr<DeviceGemmBiasActivationAdd<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -3,33 +3,45 @@ ...@@ -3,33 +3,45 @@
#pragma once #pragma once
#include <iostream> #include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
struct DEGridDesc_M0_M1_M2_N0_N1
{
ck::index_t M0_, M1_, M2_, N0_, N1_;
ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_;
};
// input : A[M, K], B[K, N],
// input : D[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D)
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceGemmBias : public BaseOperator struct DeviceGemmBiasCPermute : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
const void* p_bias, const void* p_d,
void* p_c, void* p_e,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
ck::index_t StrideC, DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc,
DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0; CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
...@@ -37,8 +49,8 @@ struct DeviceGemmBias : public BaseOperator ...@@ -37,8 +49,8 @@ struct DeviceGemmBias : public BaseOperator
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr< using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -64,8 +64,16 @@ template < ...@@ -64,8 +64,16 @@ template <
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>, is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false> bool> = false>
struct DeviceGemmDl struct DeviceGemmDl : public DeviceGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -534,8 +542,7 @@ struct DeviceGemmDl ...@@ -534,8 +542,7 @@ struct DeviceGemmDl
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
......
...@@ -11,17 +11,28 @@ namespace ck { ...@@ -11,17 +11,28 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// GEMM:
// input : A[M, K], B[K, N], // input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ... // input : D0[M, N], D1[M, N], ...
// output : E[M, N] // output : E[M, N]
// C = a_op(A) * b_op(B) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
template <ck::index_t NumDTensor, // Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceGemmMultipleD : public BaseOperator struct DeviceGemmMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
...@@ -41,14 +52,26 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -41,14 +52,26 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <ck::index_t NumDTensor, template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CDEElementwiseOperation>
using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<NumDTensor, using DeviceGemmMultipleDPtr = std::unique_ptr<DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>>; CDEElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -88,15 +88,18 @@ namespace ck { ...@@ -88,15 +88,18 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// input : A[M, K], or A[K, N] // GEMM:
// input : B[K, N], or A[N, K] // input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ... // input : D0[M, N], D1[M, N], ...
// output : E[M, N] // output : E[M, N]
// C = a_op(A) * b_op(B) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CDELayout, typename DELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -137,7 +140,13 @@ template <typename ALayout, ...@@ -137,7 +140,13 @@ template <typename ALayout,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType::Size(), struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
...@@ -357,15 +366,15 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -357,15 +366,15 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
} }
} }
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CDELayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1)); make_tuple(StrideE, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CDELayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE)); make_tuple(I1, StrideE));
...@@ -417,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -417,7 +426,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
...@@ -490,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -490,7 +499,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -512,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -512,7 +521,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n = const auto d_grid_desc_m_n =
DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -521,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -521,23 +530,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
} }
} }
// ck::Tuple<const DsDataType*...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cv_t<decltype(DsDataType{}.At(i))>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
// private: // private:
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< StaticallyIndexedArray<
...@@ -548,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -548,7 +548,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
...@@ -740,7 +744,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType: ...@@ -740,7 +744,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include "device_base.hpp" #include "device_base.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// FIXME: DeviceGemmReduce type need to well define the problem
template <ck::index_t NumDTensor, ck::index_t NumReduce> template <ck::index_t NumDTensor, ck::index_t NumReduce>
struct DeviceGemmReduce : public BaseOperator struct DeviceGemmReduce : public BaseOperator
{ {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -11,7 +12,13 @@ namespace ck { ...@@ -11,7 +12,13 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGemmSplitK : public BaseOperator struct DeviceGemmSplitK : public BaseOperator
...@@ -33,11 +40,24 @@ struct DeviceGemmSplitK : public BaseOperator ...@@ -33,11 +40,24 @@ struct DeviceGemmSplitK : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation, template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
using DeviceGemmSplitKPtr = std::unique_ptr< using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>; BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -57,8 +57,15 @@ template <typename ADataType, ...@@ -57,8 +57,15 @@ template <typename ADataType,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1> ck::index_t NumPrefetch = 1>
struct DeviceGemmXdl struct DeviceGemmXdl : public DeviceGemm<ALayout,
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -487,8 +494,7 @@ struct DeviceGemmXdl ...@@ -487,8 +494,7 @@ struct DeviceGemmXdl
index_t StrideC, index_t StrideC,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment