Commit 6de61c04 authored by ltqin

clear code

parent fb6dafee
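
In short, this commit removes the NumPrefetch template parameter from the skip-LDS GEMM path (DeviceGemmXdlSkipLds and GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3), pinning the A-matrix blockwise copy to a single prefetch stage, and drops the now-unused b_grid_desc_k0_n_k1 descriptor from the kernel and device-op call chains. The example file is also reformatted under // clang-format off so the instance parameter tables stay on aligned single lines.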
@@ -41,8 +41,8 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#define USING_SKIP_LDS 1
#if USING_SKIP_LDS
// clang-format off
#if USING_SKIP_LDS
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
@@ -61,69 +61,21 @@ using BDataType = float;
using CDataType = float;
using AccDataType = float;
#endif
// clang-format on
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, 2>;
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
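
Since USING_SKIP_LDS is defined to 1 above, only the skip-LDS instance is compiled in this example. To exercise the regular LDS-based DeviceGemmXdl instance instead, flip the switch before the instance definitions:

#define USING_SKIP_LDS 0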
@@ -52,8 +52,7 @@ template <typename ADataType,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1>
ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceGemmXdlSkipLds
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{
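
In practice, any caller that spelled out the prefetch depth must now drop the final template argument. A minimal before/after sketch; the thirty-odd leading parameters are elided and the alias name is illustrative:

// before this commit (hypothetical instantiation):
//   using Gemm = DeviceGemmXdlSkipLds<..., CThreadTransferSrcDstVectorDim,
//                                     CThreadTransferDstScalarPerVector,
//                                     2 /* NumPrefetch */>;
// after it, the prefetch depth is no longer a parameter:
//   using Gemm = DeviceGemmXdlSkipLds<..., CThreadTransferSrcDstVectorDim,
//                                     CThreadTransferDstScalarPerVector>;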
@@ -219,8 +218,7 @@ struct DeviceGemmXdlSkipLds
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumPrefetch>;
CThreadTransferDstScalarPerVector>;
// Argument
struct Argument : public BaseArgument
@@ -353,7 +351,6 @@ struct DeviceGemmXdlSkipLds
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
@@ -386,7 +383,6 @@ struct DeviceGemmXdlSkipLds
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
@@ -34,7 +34,6 @@ __global__ void
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation a_element_op,
@@ -50,7 +49,6 @@ __global__ void
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
a_element_op,
@@ -109,8 +107,7 @@ template <index_t BlockSize,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumPrefetch = 1>
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
{
static constexpr auto I0 = Number<0>{};
@@ -219,21 +216,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
return false;
// check NumPrefetch
if constexpr(NumPrefetch == 1)
{
// 1-stage prefetch always supported
}
else if constexpr(NumPrefetch == 2)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
}
else
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if(!((K0 / K0PerBlock) % 2 == 0))
{
return false;
}
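
The net effect of this hunk is that the old if constexpr dispatch on NumPrefetch disappears and the even-K0-loop requirement applies unconditionally. A minimal standalone sketch of the surviving rule; the helper name is illustrative, not the library's API:

// The K0 loop count (K0 / K0PerBlock) must be even; the retained TODO
// notes that odd loop counts are not yet supported.
constexpr bool K0LoopCountIsEven(int K0, int K0PerBlock)
{
    return (K0 / K0PerBlock) % 2 == 0;
}

static_assert(K0LoopCountIsEven(32, 4), "8 loop iterations: accepted");
static_assert(!K0LoopCountIsEven(12, 4), "3 loop iterations: rejected");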
@@ -266,7 +251,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
const bool has_main_k0_block_loop = (K0 / (MultiK0 * NumPrefetch * K0PerBlock)) > 1;
const bool has_main_k0_block_loop = (K0 / (MultiK0 * K0PerBlock)) > 1;
return has_main_k0_block_loop;
}
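
With NumPrefetch gone, the main-loop predicate divides by MultiK0 * K0PerBlock alone. A small worked sketch with illustrative numbers; MultiK0 is a compile-time constant of the gridwise GEMM, treated here as an ordinary parameter:

// Mirrors the simplified CalculateHasMainK0BlockLoop arithmetic.
constexpr bool HasMainK0BlockLoop(int K0, int MultiK0, int K0PerBlock)
{
    return (K0 / (MultiK0 * K0PerBlock)) > 1;
}

static_assert(HasMainK0BlockLoop(32, 2, 4));  // 32 / 8 = 4 > 1: main loop runs
static_assert(!HasMainK0BlockLoop(8, 2, 4));  //  8 / 8 = 1: single pass only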
@@ -410,7 +395,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation& a_element_op,
@@ -421,7 +405,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
@@ -464,13 +448,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumPrefetch>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
1>(a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
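// Note: the only change to this A-matrix blockwise copy is its final
// template argument, where the NumPrefetch parameter becomes the
// literal 1 (single-stage prefetch); schematically:
//   before: ..., AThreadTransferSrcResetCoordinateAfterRun, true, NumPrefetch>(...)
//   after:  ..., AThreadTransferSrcResetCoordinateAfterRun, true, 1>(...)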
ignore = b_element_op;
// B matrix blockwise copy