Commit a951c345 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

added skeleton files

parent 52423948
#ifndef CK_GRIDWISE_REDUX_KERNEL_WRAPPER
#define CK_GRIDWISE_REDUX_KERNEL_WRAPPER
template <class GridwiseRedux, class T>
__global__ void run_gridwise_redux_kernel(const T* const __restrict__ p_in_global,
                                          T* const __restrict__ p_out_global)
{
    // Thin launch shim: default-construct the gridwise reduction functor on the
    // device and forward the global input/output pointers to its Run() method.
    // Grid/block layout requirements are whatever GridwiseRedux::Run assumes.
    const GridwiseRedux gridwise_redux{};
    gridwise_redux.Run(p_in_global, p_out_global);
}
#endif
#ifndef CK_GRIDWISE_TENSOR_REDUX_v1
#define CK_GRIDWISE_TENSOR_REDUX_v1
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class OutGlobalDesc,
class ReduxDims
#if 0
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2,
class InBlockCopyClusterLengths_E_N1_B_N2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K
#endif
>
struct GridwiseTensorRedux_v1
{
    // Skeleton of a gridwise tensor "reduction".
    //
    // Current behavior is an identity pass-through: every thread copies exactly
    // one element of the input tensor to the same flat position in the output.
    // No reduction over ReduxDims is performed yet (skeleton commit).
    //
    // Launch contract implied by the indexing below:
    //   - BlockSize threads per block, one element per thread;
    //   - GridSize * BlockSize must equal the input element count.
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};

        // TODO: assert that except the reduced dimension all sizes are the same

        // Per-thread staging buffer; will grow to hold partial reductions.
        Float p_out_thread[1];

        // Globally unique flat element index for this thread.
        //
        // BUGFIX: the original combined a per-block multi-index (last component
        // = block id) with a per-thread multi-index (last component = thread
        // id) by element-wise addition, so e.g. (block 0, thread 5) and
        // (block 5, thread 0) aliased the same input element; worse, the store
        // used the thread-local id alone, so every block clobbered output
        // elements [0, BlockSize). Both paths now use the standard
        // block_id * BlockSize + thread_id linearization.
        const index_t global_1d_id =
            get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        // Load: translate the flat index through the (possibly strided) input
        // descriptor to the element's memory offset.
        {
            const auto in_multi_id =
                in_n_c_h_w_global_desc.GetMultiIndexFrom1dIndex(global_1d_id);

            const Float* p_in_thread_on_global =
                p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(in_multi_id);

            p_out_thread[0] = p_in_thread_on_global[0];
        }

        // Store: OutGlobalDesc is not consulted yet (skeleton); write through
        // the same flat index, as the original did via a raw pointer.
        // NOTE(review): assumes the output is packed and holds as many elements
        // as the input — revisit once an actual reduction is implemented.
        p_out_global[global_1d_id] = p_out_thread[0];
    }
};
} // namespace ck
#endif
#pragma once

#include <unistd.h>

#include <algorithm>
#include <cstdio>

#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_redux_kernel_wrapper.hpp"
//#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "gridwise_tensor_redux.hpp"
template <class T,
          class InDesc,
          class OutDesc>
void device_tensor_redux(InDesc,
                         const Tensor<T>& in_nchw,
                         OutDesc,
                         Tensor<T>& out_nkhw,
                         index_t nrepeat)
{
    // Host-side driver for the skeleton gridwise tensor reduction: uploads the
    // tensors, launches the kernel `nrepeat` times (printing each launch's
    // elapsed time), then downloads the result into out_nkhw.
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    // NOTE(review): N is read from the *output* descriptor while C/Hi/Wi come
    // from the input one — presumably dim 0 agrees between them; confirm at the
    // call site.
    constexpr index_t C  = in_nchw_desc.GetLength(I1);
    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);
    constexpr index_t N  = out_nkhw_desc.GetLength(I0);

    constexpr index_t BlockSize = 256;

    // BUGFIX hardening: the integer division below silently dropped any tail
    // when the element count was not a multiple of BlockSize, leaving elements
    // unprocessed. Make the precondition explicit at compile time.
    static_assert((N * C * Hi * Wi) % BlockSize == 0,
                  "total element count must be divisible by BlockSize");
    constexpr index_t GridSize = (N * C * Hi * Wi) / BlockSize;

    // Dimension(s) to reduce over (not yet consulted by the skeleton kernel).
    constexpr auto redux_dim = Sequence<I0>{};

    // Allocate device buffers sized by descriptor element space (covers
    // strided/padded layouts) and upload both tensors.
    const std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    // index_t's signedness is project-defined, so cast explicitly for %u.
    printf("%s: BlockSize %u, GridSize %u \n",
           __func__,
           static_cast<unsigned>(BlockSize),
           static_cast<unsigned>(GridSize));

    // The kernel functor type is loop-invariant; hoist it out of the loop.
    constexpr auto gridwise_redux = GridwiseTensorRedux_v1<GridSize,
                                                           BlockSize,
                                                           T,
                                                           decltype(in_nchw_desc),
                                                           decltype(out_nkhw_desc),
                                                           decltype(redux_dim)>{};

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_kernel(run_gridwise_redux_kernel<decltype(gridwise_redux), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<const T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms\n", time);

        // Pause between launches, capped at 10 ms; `time` is in milliseconds
        // and usleep takes microseconds.
        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment