Commit 8ce8f734 authored by Chao Liu
Browse files

pass tensor descriptor from host to device by reference, pointer and void*

parent e1eea81a
......@@ -949,7 +949,10 @@ struct DynamicFreeze
// Compiler-generated default constructor.
__host__ __device__ constexpr DynamicFreeze() = default;
// NOTE(review): this constructor and the one directly below it have the same
// parameter list; they look like the before/after sides of the diff hunk rather
// than two coexisting overloads — confirm against the real header.
// Initializes low_idx_ directly from low_idx.
__host__ __device__ constexpr DynamicFreeze(const index_t& low_idx) : low_idx_{low_idx} {}
// Initializes low_idx_ from the result of make_multi_index(low_idx).
__host__ __device__ constexpr DynamicFreeze(const index_t& low_idx)
: low_idx_{make_multi_index(low_idx)}
{
}
// Always returns 1 as the number of lower dimensions.
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
......
......@@ -16,6 +16,9 @@ template <index_t BlockSize,
typename Float,
typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -74,16 +77,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
template <typename... ADesc,
typename... BDesc,
typename... CDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
......@@ -466,16 +465,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
}
template <typename... ADesc,
typename... BDesc,
typename... CDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
......@@ -494,6 +490,57 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// Run overload that receives the three global tensor descriptors through typed
// pointers. Every descriptor is first copied out of the pointed-to object into a
// local value; the locals are then handed, together with the data pointers and
// the two compile-time loop flags, to the Run overload that takes descriptors by
// const reference.
// NOTE(review): the descriptor pointers are dereferenced unconditionally, so
// callers are expected to pass valid, readable addresses.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
                    const Float* __restrict__ p_a_global,
                    const BGlobalDesc* p_b_k_n_global_desc,
                    const Float* __restrict__ p_b_global,
                    const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
                    Float* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    // Tag types that carry the two loop flags to the callee.
    using MainLoopTag = integral_constant<bool, HasMainKBlockLoop>;
    using TailLoopTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // Local copies of the descriptors (kept as copies, exactly like before).
    const auto a_desc = *p_a_k_m_global_desc;
    const auto b_desc = *p_b_k_n_global_desc;
    const auto c_desc = *p_c_m0_m1_n0_n1_global_desc;

    Run(a_desc, p_a_global, b_desc, p_b_global, c_desc, p_c_global, MainLoopTag{}, TailLoopTag{});
}
// Run overload that receives the three global tensor descriptors as untyped
// (const void*) pointers. Each pointer is converted back to a pointer to the
// matching descriptor type (AGlobalDesc / BGlobalDesc / CGlobalDesc), the
// descriptor is copied into a local value, and the locals are handed, together
// with the data pointers and the two compile-time loop flags, to the Run
// overload that takes descriptors by const reference.
// NOTE(review): nothing here checks that the addresses really hold objects of
// those descriptor types — that is the caller's responsibility.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc,
                    const Float* __restrict__ p_a_global,
                    const void* p_b_k_n_global_desc,
                    const Float* __restrict__ p_b_global,
                    const void* p_c_m0_m1_n0_n1_global_desc,
                    Float* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    // Tag types that carry the two loop flags to the callee.
    using MainLoopTag = integral_constant<bool, HasMainKBlockLoop>;
    using TailLoopTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // Recover the typed pointers from the opaque ones.
    const AGlobalDesc* typed_a = static_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
    const BGlobalDesc* typed_b = static_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
    const CGlobalDesc* typed_c = static_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);

    // Local copies of the descriptors (kept as copies, exactly like before).
    const auto a_desc = *typed_a;
    const auto b_desc = *typed_b;
    const auto c_desc = *typed_c;

    Run(a_desc, p_a_global, b_desc, p_b_global, c_desc, p_c_global, MainLoopTag{}, TailLoopTag{});
}
};
} // namespace ck
......
......@@ -263,15 +263,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
conv_driver.Run(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
......@@ -282,17 +273,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
......@@ -11,12 +11,12 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
//#include "device_dummy_static_transform.hpp"
//#include "device_dummy_dynamic_transform_v1.hpp"
//#include "device_dummy_dynamic_transform.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
#include "device_dummy_dynamic_transform.hpp"
int main(int argc, char* argv[])
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment