Commit 8b5e63ed authored by Jing Zhang

mock up

parent 95a5af02
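This commit splits the single Float/AccFloat template-parameter pair into three independent element types: FloatAB for the A/B (weight and input) operands, FloatAcc for the per-thread accumulator, and FloatC for the output tensor, and threads the matching split (TInWei, TAcc, TOut, plus an InWeiVectorSize packing factor) through the host-side wrapper. A minimal sketch of the accumulate-then-convert pattern this separation enables, using hypothetical names rather than code from this commit:

    // Sketch only: why the element types are kept separate.
    // FloatAB = storage type of the A/B operands, FloatAcc = accumulator type,
    // FloatC  = output storage type (mirroring the template parameters below).
    template <typename FloatAB, typename FloatAcc, typename FloatC>
    void gemm_accumulate_then_convert(const FloatAB* a, const FloatAB* b, FloatC* c,
                                      int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                FloatAcc acc = 0; // accumulate in a wider type (e.g. float for fp16 inputs)
                for(int k = 0; k < K; ++k)
                    acc += static_cast<FloatAcc>(a[k * M + m]) * static_cast<FloatAcc>(b[k * N + n]);
                c[m * N + n] = static_cast<FloatC>(acc); // convert once at writeback
            }
    }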
@@ -10,8 +10,9 @@
namespace ck {
template <index_t BlockSize,
- typename Float,
- typename AccFloat,
+ typename FloatAB,
+ typename FloatAcc,
+ typename FloatC,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
@@ -42,9 +43,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
- const Float* __restrict__ p_wei_global,
- const Float* __restrict__ p_in_global,
- Float* __restrict__ p_out_global) const
+ const FloatAB* __restrict__ p_wei_global,
+ const FloatAB* __restrict__ p_in_global,
+ FloatC* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -166,8 +167,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize,
- Float,
- AccFloat,
+ FloatAB,
+ FloatAcc,
+ FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc),
@@ -227,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
- const Float*,
+ const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
- const Float*,
+ const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
- Float*,
+ FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
@@ -254,11 +256,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
- const Float*,
+ const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
- const Float*,
+ const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
- Float*,
+ FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
@@ -281,11 +283,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
- const Float*,
+ const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
- const Float*,
+ const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
- Float*,
+ FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
@@ -308,11 +310,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
- const Float*,
+ const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
- const Float*,
+ const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
- Float*,
+ FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
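The four hunks above are the four instantiations of the same kernel launch, one per combination of the HasMainKBlockLoop / HasDoubleTailKBlockLoop flags passed as integral_constant<bool, ...>; only the pointer types change from Float* to FloatAB* / FloatC*. A rough standalone sketch of this compile-time dispatch pattern, with made-up names rather than the driver's actual selection code:

    // Each bool pair selects a separately compiled kernel specialization.
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    void launch_specialized_kernel()
    {
        // launch the kernel instantiated for this (main loop, double tail) case
    }

    // Runtime conditions are mapped onto the four compile-time cases once, up front.
    inline void dispatch(bool has_main_loop, bool has_double_tail)
    {
        if(has_main_loop && has_double_tail)
            launch_specialized_kernel<true, true>();
        else if(has_main_loop)
            launch_specialized_kernel<true, false>();
        else if(has_double_tail)
            launch_specialized_kernel<false, true>();
        else
            launch_specialized_kernel<false, false>();
    }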
@@ -12,8 +12,9 @@
namespace ck {
template <index_t BlockSize,
- typename Float,
- typename AccFloat,
+ typename FloatAB,
+ typename FloatAcc,
+ typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
@@ -64,17 +65,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
- return a_block_space_size * sizeof(Float);
+ return a_block_space_size * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
- const Float* __restrict__ p_a_global,
+ const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
- const Float* __restrict__ p_b_global,
+ const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
- Float* __restrict__ p_c_global,
- Float* __restrict__ p_shared_block,
+ FloatC* __restrict__ p_c_global,
+ FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
@@ -177,8 +178,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
- Float,
- Float,
+ FloatAB,
+ FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
@@ -203,8 +204,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
- Float,
- Float,
+ FloatAB,
+ FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
@@ -218,10 +219,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
- Float* p_a_block = p_shared_block;
+ FloatAB* p_a_block = p_shared_block;
// register allocation for output
- AccFloat p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
+ FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
@@ -240,9 +241,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
- Float p_b_thread[b_thread_space_size * 2];
+ FloatAB p_b_thread[b_thread_space_size * 2];
- Float* p_b_thread_double = p_b_thread;
+ FloatAB* p_b_thread_double = p_b_thread;
// LDS double buffer: preload data into LDS
{
@@ -265,8 +266,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1
if constexpr(HasMainKBlockLoop)
{
- Float* p_b_thread_even = p_b_thread_double;
- Float* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
+ FloatAB* p_b_thread_even = p_b_thread_double;
+ FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
@@ -359,8 +360,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3<
- AccFloat,
- Float,
+ FloatAcc,
+ FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
@@ -388,17 +389,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
- const Float* __restrict__ p_a_global,
+ const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
- const Float* __restrict__ p_b_global,
+ const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
- Float* __restrict__ p_c_global,
+ FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
- constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
+ constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
- __shared__ Float p_shared_block[shared_block_size];
+ __shared__ FloatAB p_shared_block[shared_block_size];
Run(a_e_k_global_desc,
p_a_global,
@@ -414,11 +415,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
- const Float* __restrict__ p_a_global,
+ const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
- const Float* __restrict__ p_b_global,
+ const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
- Float* __restrict__ p_c_global,
+ FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
@@ -439,11 +440,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
- const Float* __restrict__ p_a_global,
+ const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
- const Float* __restrict__ p_b_global,
+ const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
- Float* __restrict__ p_c_global,
+ FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
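The hunks above only swap the element type from Float to FloatAB/FloatAcc/FloatC; the surrounding structure is untouched, including the double-buffered B staging (p_b_thread is twice the per-slice size, with even/odd halves alternating so the next slice can be fetched while the current one feeds the math). A simplified CPU-side sketch of that double-buffering idea, with hypothetical helpers instead of the kernel's real transfer classes:

    #include <vector>

    // Software double buffering: fetch slice i+1 into one half of the buffer
    // while computing on slice i held in the other half.
    void double_buffered_loop(int num_slices, int slice_size)
    {
        std::vector<float> buf(2 * slice_size, 0.0f);
        float* even = buf.data();
        float* odd  = buf.data() + slice_size;

        auto fetch   = [&](int i, float* dst) { for(int j = 0; j < slice_size; ++j) dst[j] = float(i); };
        auto compute = [&](const float* src) { volatile float sink = src[0]; (void)sink; };

        fetch(0, even); // preload the first slice
        for(int i = 0; i < num_slices; ++i)
        {
            float* cur  = (i % 2 == 0) ? even : odd;
            float* next = (i % 2 == 0) ? odd  : even;
            if(i + 1 < num_slices)
                fetch(i + 1, next); // bring in the next slice while...
            compute(cur);           // ...consuming the current one
        }
    }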
@@ -3,7 +3,10 @@
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
- template <class T,
+ template <class TInWei,
+ ck::index_t InWeiVectorSize,
+ class TAcc,
+ class TOut,
class InDesc,
class WeiDesc,
class OutDesc,
@@ -11,33 +14,31 @@ template <class T,
class ConvDilations,
class InLeftPads,
class InRightPads>
- void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc,
- const Tensor<T>& in_nchw,
+ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
+ InDesc,
+ const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc,
- const Tensor<T>& wei_kcyx,
+ const Tensor<TInWei>& wei_k_c_y_x,
OutDesc,
- Tensor<T>& out_nkhw,
+ Tensor<TOut>& out_n_k_ho_wo,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
<< std::endl;
using namespace ck;
- using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
- std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
- << std::endl;
- std::size_t data_sz = sizeof(T);
- DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
- DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
- DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
+ DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
+ DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+ DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
- in_nchw_device_buf.ToDevice(in_nchw.mData.data());
- wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
- out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
+ in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
+ wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+ out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 0
// run-time variables
@@ -70,18 +71,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
- constexpr index_t KPerBlock = 16;
+ constexpr index_t KPerBlock = 4;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
- constexpr index_t EPerBlock = 4;
+ constexpr index_t EPerBlock = 1;
- constexpr index_t KPerThread = 16;
+ constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
- constexpr index_t EPerThread = 4;
+ constexpr index_t EPerThread = 1;
- using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
- using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>;
+ using ABlockTransferThreadSliceLengths_E_K = Sequence<1, 1>;
+ using ABlockTransferThreadClusterLengths_E_K = Sequence<9, 4>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
@@ -93,8 +94,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
BlockSize,
- TDevice,
- TDevice,
+ typename vector_type<TInWei, InWeiVectorSize>::type,
+ TAcc,
+ TOut,
KPerBlock,
HoPerBlock,
WoPerBlock,
@@ -117,9 +119,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
conv_dilations,
in_left_pads,
in_right_pads,
- static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
- static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
- static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
+ static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+ wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+ static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+ in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+ static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
- out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
+ out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
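The wrapper now sizes its device buffers from TInWei/TOut and hands the input and weight buffers to the driver as vector_type<TInWei, InWeiVectorSize>::type*, so each load can move InWeiVectorSize packed elements at once. A loose illustration of that reinterpretation, assuming a 4-wide packing and a plain struct standing in for the ck vector_type machinery:

    #include <cstddef>

    // Stand-in for vector_type<float, 4>::type: four scalars handled as one element.
    struct packed4 { float x, y, z, w; };

    // GetDeviceBuffer() presumably hands back an untyped pointer; the wrapper
    // static_casts it to the packed element type the driver template expects.
    packed4* as_packed(void* raw_buffer) { return static_cast<packed4*>(raw_buffer); }

    // Number of packed elements backing a buffer of scalar_count scalars.
    std::size_t packed_count(std::size_t scalar_count) { return scalar_count / 4; }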
@@ -80,10 +80,10 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>;
#elif 1
constexpr index_t N = 1;
- constexpr index_t C = 4;
+ constexpr index_t C = 1;
constexpr index_t HI = 1024;
constexpr index_t WI = 2048;
- constexpr index_t K = 16;
+ constexpr index_t K = 4;
constexpr index_t Y = 3;
constexpr index_t X = 3;
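With the shrunken test case above, the extents line up with the smaller tile parameters chosen in the wrapper earlier in this diff (KPerBlock = KPerThread = 4, EPerBlock = 1, ABlockTransferThreadClusterLengths_E_K = Sequence<9, 4>), assuming the usual implicit-GEMM mapping in which the reduction extent per output pixel is C * Y * X:

    // Back-of-the-envelope check for the new problem size (assumption: E = C * Y * X).
    constexpr int C = 1, K = 4, Y = 3, X = 3;
    constexpr int E = C * Y * X; // 1 * 3 * 3 = 9 -> matches the cluster length 9 along E
    static_assert(E == 9, "one 3x3 window per input channel");
    static_assert(K == 4, "output channels -> KPerBlock = KPerThread = 4");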
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{}));
- #if 1
+ #if 0
using in_data_t = float;
constexpr index_t in_vector_size = 1;
using acc_data_t = float;
@@ -754,7 +754,11 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 1
- device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(in_nchw_desc,
+ device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
+ in_vector_size,
+ acc_data_t,
+ out_data_t>(
+ in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,