Commit 8ce8f734 authored by Chao Liu's avatar Chao Liu
Browse files

Pass tensor descriptors from host to device by reference, by pointer, and by void*

parent e1eea81a
...@@ -949,7 +949,10 @@ struct DynamicFreeze ...@@ -949,7 +949,10 @@ struct DynamicFreeze
__host__ __device__ constexpr DynamicFreeze() = default; __host__ __device__ constexpr DynamicFreeze() = default;
__host__ __device__ constexpr DynamicFreeze(const index_t& low_idx) : low_idx_{low_idx} {} __host__ __device__ constexpr DynamicFreeze(const index_t& low_idx)
: low_idx_{make_multi_index(low_idx)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
......
...@@ -16,6 +16,9 @@ template <index_t BlockSize, ...@@ -16,6 +16,9 @@ template <index_t BlockSize,
typename Float, typename Float,
typename AccFloat, typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -74,16 +77,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -74,16 +77,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
} }
template <typename... ADesc, template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
typename... BDesc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
typename... CDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block, Float* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -466,16 +465,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -466,16 +465,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
} }
} }
template <typename... ADesc, // pass tensor descriptor by reference
typename... BDesc, template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
typename... CDesc, __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
...@@ -494,6 +490,57 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -494,6 +490,57 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
// Pass tensor descriptors by pointer.
//
// Thin forwarding overload: dereferences each descriptor pointer and calls the
// primary by-reference Run overload. Dereferencing directly in the call binds
// the callee's const-reference parameters, so no descriptor copies are made
// here (the previous form materialized a by-value copy of every descriptor
// before forwarding, which costs registers/local memory in device code).
//
// NOTE(review): pointers are assumed non-null and to reference descriptors
// resident in device-visible memory -- confirm at the call site.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
                    const Float* __restrict__ p_a_global,
                    const BGlobalDesc* p_b_k_n_global_desc,
                    const Float* __restrict__ p_b_global,
                    const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
                    Float* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    Run(*p_a_k_m_global_desc,
        p_a_global,
        *p_b_k_n_global_desc,
        p_b_global,
        *p_c_m0_m1_n0_n1_global_desc,
        p_c_global,
        integral_constant<bool, HasMainKBlockLoop>{},
        integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// Pass tensor descriptors as type-erased void* (e.g. when descriptors travel
// through a generic kernel-argument buffer).
//
// Recovers the concrete descriptor types with static_cast -- the idiomatic
// cast from void* back to the original pointee type (reinterpret_cast is not
// needed for void* round-trips) -- and forwards to the primary by-reference
// Run overload. Dereferencing directly in the call binds the callee's
// const-reference parameters, so no descriptor copies are made here.
//
// NOTE(review): each void* is assumed to actually point at a descriptor of
// the corresponding AGlobalDesc/BGlobalDesc/CGlobalDesc type in
// device-visible memory -- the cast is unchecked; confirm at the call site.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc,
                    const Float* __restrict__ p_a_global,
                    const void* p_b_k_n_global_desc,
                    const Float* __restrict__ p_b_global,
                    const void* p_c_m0_m1_n0_n1_global_desc,
                    Float* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    Run(*static_cast<const AGlobalDesc*>(p_a_k_m_global_desc),
        p_a_global,
        *static_cast<const BGlobalDesc*>(p_b_k_n_global_desc),
        p_b_global,
        *static_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc),
        p_c_global,
        integral_constant<bool, HasMainKBlockLoop>{},
        integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
}; };
} // namespace ck } // namespace ck
......
...@@ -263,36 +263,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -263,36 +263,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{}; GemmCThreadTransferDstScalarPerVector_GemmN1>{};
for(index_t i = 0; i < 5; ++i) conv_driver.Run(wei_k_c_y_x_desc,
{ in_n_c_hi_wi_desc,
std::cout << "Start running " << nrepeat << " times..." << std::endl; out_n_k_ho_wo_desc,
conv_strides,
KernelTimer timer; conv_dilations,
timer.Start(); in_left_pads,
in_right_pads,
for(index_t j = 0; j < nrepeat; ++j) static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
{ static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
conv_driver.Run(wei_k_c_y_x_desc, static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
} }
...@@ -11,12 +11,12 @@ ...@@ -11,12 +11,12 @@
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv.hpp" #include "host_conv.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
//#include "device_dummy_static_transform.hpp" #include "device_dummy_static_transform.hpp"
//#include "device_dummy_dynamic_transform_v1.hpp" #include "device_dummy_dynamic_transform_v1.hpp"
//#include "device_dummy_dynamic_transform.hpp" #include "device_dummy_dynamic_transform.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment