"tests/vscode:/vscode.git/clone" did not exist on "2bb2a1343345fb1870ccce9ab8216261d1e5f431"
Commit 0404f777 authored by Chao Liu
Browse files

refactor

parent 08c7f743
......@@ -5,7 +5,7 @@
#include "nvToolsExt.h"
#include "tensor.hpp"
#include "constant_tensor_descriptor.cuh"
#include "direct_convolution.cuh"
#include "direct_convolution_2.cuh"
template <class T>
struct GeneratorConstant
......@@ -133,8 +133,8 @@ void device_convolution(
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr unsigned NPerBlock = 1;
constexpr unsigned KPerBlock = 1;
constexpr unsigned CPerBlockLoop = 1;
constexpr unsigned KPerBlock = 2;
constexpr unsigned CPerBlockLoop = 4;
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned YPerBlock = 8;
......@@ -213,7 +213,7 @@ int main()
constexpr unsigned C = 256;
constexpr unsigned HI = 34;
constexpr unsigned WI = 34;
constexpr unsigned K = 56;
constexpr unsigned K = 64;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0
......
......@@ -62,13 +62,9 @@ __device__ void blockwise_4d_tensor_op(
{
for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
{
const unsigned sindex =
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const unsigned dindex =
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
// sindex is computed from src_desc and dindex from dst_desc, so each
// pointer must be addressed with its own descriptor's offset.
f(p_src[sindex], p_dst[dindex]);
......@@ -115,6 +111,8 @@ __device__ void blockwise_4d_tensor_op(
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
constexpr auto desc = make_ConstantTensorDescriptor(src_desc.GetLengths());
#if 0
if(threadIdx.x == 0)
{
......@@ -125,29 +123,27 @@ __device__ void blockwise_4d_tensor_op(
unsigned lid = threadIdx.x;
for(unsigned i = lid; i < src_desc.GetElementSize(); i += BlockSize)
for(unsigned i = lid; i < desc.GetElementSize(); i += BlockSize)
{
unsigned is = i;
const unsigned did0 = is / src_desc.GetStride(I0);
const unsigned did0 = is / desc.GetStride(I0);
is -= did0 * src_desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const unsigned did1 = is / src_desc.GetStride(I1);
const unsigned did1 = is / desc.GetStride(I1);
is -= did1 * src_desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const unsigned did2 = is / src_desc.GetStride(I2);
const unsigned did2 = is / desc.GetStride(I2);
is -= did2 * src_desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const unsigned did3 = is / src_desc.GetStride(I3);
const unsigned sindex = src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const unsigned dindex = dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
f(p_src[sindex], p_dst[dindex]);
}
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment