Commit 5b3bd032 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 76ee0baf
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_cshuffle.hpp" #include "device_gemm_xdl_cshuffle.hpp"
#include "device_gemm_xdl_cshuffle_v2.hpp" #include "device_gemm_xdl_producer_consumer_cshuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
...@@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle ...@@ -57,24 +57,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// // 2-stage prefetch // // 2-stage prefetch
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 0 #elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// all thread
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 0, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 0, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#elif 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_v2
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// producer & consumer // < Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>; < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
// < Row, Col, Row, F16, F16, F16, F32, F16, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 8>, 8>;
#elif 1 #elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v2.hpp" #include "gridwise_gemm_xdl_producer_consumer_cshuffle.hpp"
#include "tensor_operation/gpu/device/gemm_specialization.hpp" #include "gemm_specialization.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -56,10 +56,10 @@ template <typename ALayout, ...@@ -56,10 +56,10 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct DeviceGemm_Xdl_CShuffle_v2 struct DeviceGemm_Xdl_ProducerConsumer_CShuffle
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
using DeviceOp = DeviceGemm_Xdl_CShuffle_v2; using DeviceOp = DeviceGemm_Xdl_ProducerConsumer_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -334,7 +334,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -471,7 +471,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v2< const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -523,7 +523,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v2< const auto kernel = kernel_gemm_xdl_producer_consumer_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2 ...@@ -672,7 +672,7 @@ struct DeviceGemm_Xdl_CShuffle_v2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGemm_Xdl_CShuffle_v2" str << "DeviceGemm_Xdl_ProducerConsumer_CShuffle"
<< "<" << "<"
<< ABBlockTransferThreadGroupSize << ", " << ABBlockTransferThreadGroupSize << ", "
<< BlockGemmThreadGroupSize << ", " << BlockGemmThreadGroupSize << ", "
......
#pragma once
#include "common_header.hpp"
namespace ck {
template <typename ABBlockTransferThreadGroup,
typename BlockGemmThreadGroup,
index_t NumGemmKPrefetchStage>
struct GridwiseGemmPipelineProducerConsumer;
// 1-stage prefetch
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup, BlockGemmThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop / 2 > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_sync_lds();
// move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_block_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_block_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_block_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_sync_lds();
// LDS write num_loop - 1
a_block_copy.RunWrite(a_block_desc, a_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
// GEMM num_loop - 1
}
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 2
// LDS write i + 1
// global read i + 2
// LDS write i + 1
// global read i + 2
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(ABBlockTransferThreadGroup::IsBelong())
{
RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
a_block_desc,
a_block_copy,
a_grid_buf,
a_block_buf,
a_block_copy_step,
b_grid_desc,
b_block_desc,
b_block_copy,
b_grid_buf,
b_block_buf,
b_block_copy_step,
num_loop);
}
else if(BlockGemmThreadGroup::IsBelong())
{
RunBlockGemmPipeline<HasMainLoop>(
a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
}
}
};
} // namespace ck
...@@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -51,7 +51,6 @@ struct GridwiseGemmPipeline_v1<1>
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
#if 1
// preload data into LDS // preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
...@@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1> ...@@ -98,80 +97,6 @@ struct GridwiseGemmPipeline_v1<1>
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
#elif 1
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
// FIXME: HasMainLoop = (num_loop) > 2
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#endif
} }
}; };
......
#pragma once #pragma once
#include "common_header.hpp" #include "common_header.hpp"
namespace ck { namespace ck {
template <typename ABBlockTransferThreadGroup, typename BlockGemmThreadGroup>
struct GridwiseGemmPipeline_v2 struct GridwiseGemmPipeline_v2
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__device__ constexpr GridwiseGemmPipeline_v2()
{
// TODO static assert
}
__host__ __device__ static constexpr bool IsSupported(index_t num_loop) __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{ {
// TODO: improve applicability // TODO: improve applicability
...@@ -23,7 +13,7 @@ struct GridwiseGemmPipeline_v2 ...@@ -23,7 +13,7 @@ struct GridwiseGemmPipeline_v2
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{ {
return num_loop / 2 > 1; return (num_loop / 2) > 1;
} }
template <bool HasMainLoop, template <bool HasMainLoop,
...@@ -39,41 +29,46 @@ struct GridwiseGemmPipeline_v2 ...@@ -39,41 +29,46 @@ struct GridwiseGemmPipeline_v2
typename BGridBuffer, typename BGridBuffer,
typename BBlockBuffer, typename BBlockBuffer,
typename BBlockTransferStep, typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer> typename CThreadBuffer>
static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc, __device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc, const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy, ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf, const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf, ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy, BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
index_t num_loop) const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{ {
// global read 0 // global read 0
a_block_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_block_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1 // move to 1
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// LDS write 0 // LDS write 0
a_block_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1 // global Read 1
a_block_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0 // LDS write 0
b_block_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1 // global Read 1
b_block_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// main body // main body
// FIXME: HasMainLoop = (num_loop) > 2
if constexpr(HasMainLoop) if constexpr(HasMainLoop)
{ {
index_t i = 0; index_t i = 0;
...@@ -83,22 +78,23 @@ struct GridwiseGemmPipeline_v2 ...@@ -83,22 +78,23 @@ struct GridwiseGemmPipeline_v2
block_sync_lds(); block_sync_lds();
// GEMM i // GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds(); block_sync_lds();
// move to i + 2 // move to i + 2
a_block_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_block_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1 // LDS write i + 1
a_block_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2 // global read i + 2
a_block_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1 // LDS write i + 1
b_block_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2 // global read i + 2
b_block_copy.RunRead(b_grid_desc, b_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
++i; ++i;
} while(i < (num_loop - 2)); } while(i < (num_loop - 2));
...@@ -109,128 +105,18 @@ struct GridwiseGemmPipeline_v2 ...@@ -109,128 +105,18 @@ struct GridwiseGemmPipeline_v2
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 2 // GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds(); block_sync_lds();
// LDS write num_loop - 1 // LDS write num_loop - 1
a_block_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds(); block_sync_lds();
// GEMM num_loop - 1 // GEMM num_loop - 1
} blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunBlockGemmPipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 2
// LDS write i + 1
// global read i + 2
// LDS write i + 1
// global read i + 2
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_block_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_block_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
if(ABBlockTransferThreadGroup::IsBelong())
{
RunABBlockTransferPipeline<HasMainLoop>(a_grid_desc,
a_block_desc,
a_block_copy,
a_grid_buf,
a_block_buf,
a_block_copy_step,
b_grid_desc,
b_block_desc,
b_block_copy,
b_grid_buf,
b_block_buf,
b_block_copy_step,
num_loop);
}
else if(BlockGemmThreadGroup::IsBelong())
{
RunBlockGemmPipeline<HasMainLoop>(
a_block_buf, b_block_buf, block_gemm, c_thread_buf, num_loop);
} }
} }
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace ck { namespace ck {
...@@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -127,7 +128,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if 1
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
#else
using GridwiseGemmPipe = GridwiseGemmPipeline_v2;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_producer_consumer.hpp"
#include "gridwise_gemm_pipeline_v2.hpp"
namespace ck { namespace ck {
...@@ -27,18 +26,20 @@ __global__ void ...@@ -27,18 +26,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_producer_consumer_cshuffle(
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_a_grid,
FloatC* __restrict__ p_c_grid, const FloatAB* __restrict__ p_b_grid,
const AElementwiseOperation a_element_op, FloatC* __restrict__ p_c_grid,
const BElementwiseOperation b_element_op, const AElementwiseOperation a_element_op,
const CElementwiseOperation c_element_op, const BElementwiseOperation b_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const CElementwiseOperation c_element_op,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const Block2CTileMap block_2_ctile_map) c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...@@ -52,6 +53,18 @@ __global__ void ...@@ -52,6 +53,18 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename FloatAB, template <typename FloatAB,
...@@ -97,7 +110,7 @@ template <typename FloatAB, ...@@ -97,7 +110,7 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -114,10 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
using ThisThreadBlock =
ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
#if 0
struct ABBlockTransferThreadGroup struct ABBlockTransferThreadGroup
{ {
__device__ static constexpr index_t GetNumOfThread() __device__ static constexpr index_t GetNumOfThread()
...@@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 ...@@ -151,22 +160,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2
} }
}; };
using CShuffleBlockTransferThreadGroup = ThisThreadBlock; using CShuffleBlockTransferThreadGroup =
#else ThisThreadBlock<ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize>;
using ABBlockTransferThreadGroup = ThisThreadBlock;
using BlockGemmThreadGroup = ThisThreadBlock;
using CShuffleBlockTransferThreadGroup = ThisThreadBlock;
#endif
#if 1 using GridwiseGemmPipe = GridwiseGemmPipelineProducerConsumer<ABBlockTransferThreadGroup,
// gridwise GEMM pipeline BlockGemmThreadGroup,
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; NumGemmKPrefetchStage>;
#else
// gridwise GEMM pipeline
using GridwiseGemmPipe = GridwiseGemmPipeline_v2<ABBlockTransferThreadGroup,
BlockGemmThreadGroup,
NumGemmKPrefetchStage>;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
......
...@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N, ...@@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N,
} }
ConvParams::ConvParams() ConvParams::ConvParams()
: num_dim_spatial(2), : num_dim_spatial(2),
N(128), N(128),
K(256), K(256),
C(192), C(192),
filter_spatial_lengths(2, 3), filter_spatial_lengths(2, 3),
input_spatial_lengths(2, 71), input_spatial_lengths(2, 71),
conv_filter_strides(2, 2), conv_filter_strides(2, 2),
conv_filter_dilations(2, 1), conv_filter_dilations(2, 1),
input_left_pads(2, 1), input_left_pads(2, 1),
input_right_pads(2, 1) input_right_pads(2, 1)
{ {
} }
...@@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim, ...@@ -77,9 +77,9 @@ ConvParams::ConvParams(ck::index_t n_dim,
conv_filter_dilations.size() != num_dim_spatial || conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
{ {
throw(std::runtime_error( throw(
"ConvParams::GetOutputSpatialLengths: " std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!")); "parameter size is different from number of declared dimensions!"));
} }
} }
...@@ -91,9 +91,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const ...@@ -91,9 +91,9 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
conv_filter_dilations.size() != num_dim_spatial || conv_filter_dilations.size() != num_dim_spatial ||
input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial)
{ {
throw(std::runtime_error( throw(
"ConvParams::GetOutputSpatialLengths: " std::runtime_error("ConvParams::GetOutputSpatialLengths: "
"parameter size is different from number of declared dimensions!")); "parameter size is different from number of declared dimensions!"));
} }
std::vector<ck::index_t> out_spatial_len(num_dim_spatial, 0); std::vector<ck::index_t> out_spatial_len(num_dim_spatial, 0);
...@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const ...@@ -101,8 +101,7 @@ std::vector<ck::index_t> ConvParams::GetOutputSpatialLengths() const
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t idx_eff = const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
out_spatial_len[i] = out_spatial_len[i] =
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
conv_filter_strides[i] + conv_filter_strides[i] +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment