Commit 8732ea04 authored by Chao Liu's avatar Chao Liu
Browse files

tweaked params for direct conv; added a dummy winograd

parent dbffe05a
......@@ -7,6 +7,7 @@
#include "constant_tensor_descriptor.cuh"
#include "device_direct_convolution_1.cuh"
#include "device_direct_convolution_2.cuh"
#include "device_winograd_convolution.cuh"
struct GeneratorConstant
{
......@@ -395,7 +396,7 @@ int main()
Tensor<float> out_host(make_TensorDescriptor(out_desc));
Tensor<float> out_device(make_TensorDescriptor(out_desc));
#if 1
#if 0
std::size_t num_thread = std::thread::hardware_concurrency();
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
......@@ -403,16 +404,20 @@ int main()
for(int i = 0; i < 20; ++i)
{
device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device);
#if 1
device_direct_convolution_2(in_desc, in, wei_desc, wei, out_desc, out_device);
#else
device_winograd_convolution(in_desc, in, wei_desc, wei, out_desc, out_device);
#endif
}
#if 0
host_direct_convolution(in, wei, out_host);
#else
check_error(out_host, out_device);
#elif 0
host_winograd_3x3_convolution(in, wei, out_host);
#endif
check_error(out_host, out_device);
#endif
#if 0
LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
......@@ -420,4 +425,4 @@ int main()
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
#endif
}
\ No newline at end of file
}
......@@ -28,7 +28,7 @@ void device_direct_convolution_2(
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 2;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 1;
constexpr unsigned XPerBlock = 16;
......
#pragma once
#include "gridwise_winograd_convolution.cuh"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_winograd_convolution(
InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out)
{
std::size_t data_sz = sizeof(T);
DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace());
int num_thread = std::thread::hardware_concurrency();
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
out_device_buf.ToDevice(out.mData.data());
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr unsigned OutTileSizeH = 2;
constexpr unsigned OutTileSizeW = 2;
constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 16;
constexpr unsigned CPerBlock = 4;
constexpr unsigned YPerBlock = 1;
constexpr unsigned XPerBlock = 16;
constexpr unsigned NPerThread = 2;
constexpr unsigned KPerThread = 2;
constexpr unsigned CPerThread = 2;
constexpr unsigned BlockSize = 128;
constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) *
(out_desc.GetLength(I1) / KPerBlock) *
(out_desc.GetLength(I2) / (OutTileSizeH * YPerBlock)) *
(out_desc.GetLength(I3) / (OutTileSizeW * XPerBlock));
dim3 block_dim(BlockSize);
dim3 grid_dim(GridSize);
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
cudaEvent_t start, stop;
float elapsedTime;
cudaEventCreate(&start);
cudaEventRecord(start, 0);
gridwise_winograd_convolution<T,
InDesc,
WeiDesc,
OutDesc,
OutTileSizeH,
OutTileSizeW,
NPerBlock,
KPerBlock,
CPerBlock,
YPerBlock,
XPerBlock,
NPerThread,
KPerThread,
CPerThread,
BlockSize,
GridSize>
<<<grid_dim, block_dim>>>(InDesc{},
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
WeiDesc{},
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
OutDesc{},
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
cudaEventCreate(&stop);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime);
checkCudaErrors(cudaGetLastError());
out_device_buf.FromDevice(out.mData.data());
}
\ No newline at end of file
#pragma once
#include "constant_tensor_descriptor.cuh"
template <class TFloat,
unsigned InTileSizeH,
unsigned InTileSizeW,
unsigned S,
unsigned R,
unsigned OutTileSizeH,
unsigned OutTileSizeW,
unsigned NPerBlock,
unsigned CPerBlock,
unsigned YPerBlock,
unsigned XPerBlock,
unsigned BlockSize>
__device__ void blockwise_winograd_transform_input(TFloat* const __restrict__ p_in,
TFloat* __restrict__ p_in_transform)
{
p_in_transform[0] = 1;
}
template <class TFloat,
unsigned InTileSizeH,
unsigned InTileSizeW,
unsigned S,
unsigned R,
unsigned OutTileSizeH,
unsigned OutTileSizeW,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned BlockSize>
__device__ void blockwise_winograd_transform_weight(TFloat* const __restrict__ p_wei,
TFloat* __restrict__ p_wei_transform)
{
p_wei_transform[0] = 1;
}
\ No newline at end of file
......@@ -214,8 +214,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
__syncthreads();
for(unsigned c_thread_data_offset = 0; c_thread_data_offset < CPerBlock;
c_thread_data_offset += CPerThread)
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// copy input tensor into register
threadwise_4d_tensor_op_binary<TFloat,
......@@ -224,7 +223,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
decltype(f_copy)>(
in_thread_block_desc,
p_in_block + in_block_desc.Get1dIndex(n_thread_data_offset,
c_thread_data_offset,
c_thread_data,
hi_thread_data_offset,
wi_thread_data_offset),
in_thread_desc,
......@@ -237,8 +236,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
decltype(wei_thread_desc),
decltype(f_copy)>(
wei_thread_block_desc,
p_wei_block +
wei_block_desc.Get1dIndex(k_thread_data_offset, c_thread_data_offset, 0, 0),
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_offset, c_thread_data, 0, 0),
wei_thread_desc,
p_wei_thread,
f_copy);
......@@ -269,4 +267,4 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
ho_block_data_offset + ho_thread_data_offset,
wo_block_data_offset + wo_thread_data_offset),
f_copy);
}
\ No newline at end of file
}
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "blockwise_winograd_transform.cuh"
#include "threadwise_winograd_transform.cuh"
template <class TFloat,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
unsigned OutTileSizeH,
unsigned OutTileSizeW,
unsigned NPerBlock,
unsigned KPerBlock,
unsigned CPerBlock,
unsigned YPerBlock,
unsigned XPerBlock,
unsigned NPerThread,
unsigned KPerThread,
unsigned CPerThread,
unsigned BlockSize,
unsigned GridSize>
__global__ void gridwise_winograd_convolution(InGlobalDesc,
TFloat* const __restrict__ p_in_global,
WeiGlobalDesc,
TFloat* const __restrict__ p_wei_global,
OutGlobalDesc,
TFloat* __restrict__ p_out_global)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto in_global_desc = InGlobalDesc{};
constexpr auto wei_global_desc = WeiGlobalDesc{};
constexpr auto out_global_desc = OutGlobalDesc{};
constexpr unsigned S = wei_global_desc.GetLength(I2);
constexpr unsigned R = wei_global_desc.GetLength(I3);
constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;
constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;
constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;
// divide block work
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
const unsigned block_id = blockIdx.x;
unsigned itmp = block_id;
const unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork);
itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork);
const unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork);
itmp -= k_block_work_id * (YBlockWork * XBlockWork);
const unsigned y_block_work_id = itmp / XBlockWork;
const unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork;
const unsigned n_block_data_offset = n_block_work_id * NPerBlock;
const unsigned k_block_data_offset = k_block_work_id * KPerBlock;
const unsigned y_block_data_offset = y_block_work_id * YPerBlock;
const unsigned x_block_data_offset = x_block_work_id * XPerBlock;
const unsigned ho_block_data_offset = y_block_data_offset * OutTileSizeH;
const unsigned wo_block_data_offset = x_block_data_offset * OutTileSizeW;
const unsigned hi_block_data_offset = ho_block_data_offset; // minus padding
const unsigned wi_block_data_offset = wo_block_data_offset; // minus padding
// divide thread work
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
constexpr unsigned YThreadWork = YPerBlock;
constexpr unsigned XThreadWork = XPerBlock;
const unsigned thread_id = threadIdx.x;
itmp = thread_id;
const unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
const unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
const unsigned y_thread_work_id = itmp / XThreadWork;
const unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;
const unsigned n_thread_data_offset = n_thread_work_id * NPerThread;
const unsigned k_thread_data_offset = k_thread_work_id * KPerThread;
const unsigned y_thread_data_offset = y_thread_work_id;
const unsigned x_thread_data_offset = x_thread_work_id;
// op
auto f_set0 = [](TFloat& v) { v = TFloat(0); };
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
// block data
constexpr auto in_transform_block_desc = make_ConstantTensorDescriptor(
Sequence<NPerBlock, CPerBlock, YPerBlock * InTileSizeH, XPerBlock * InTileSizeW>{});
constexpr auto wei_transform_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, InTileSizeH, InTileSizeW>{});
constexpr unsigned in_transform_block_size = in_transform_block_desc.GetElementSpace();
constexpr unsigned wei_transform_block_size = wei_transform_block_desc.GetElementSpace();
__shared__ TFloat p_in_transform_block[in_transform_block_size];
__shared__ TFloat p_wei_transform_block[wei_transform_block_size];
// thread data
constexpr auto in_transform_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{},
in_transform_block_desc.GetStrides());
constexpr auto wei_transform_thread_block_desc =
make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, InTileSizeH, InTileSizeW>{},
wei_transform_block_desc.GetStrides());
constexpr auto out_transform_thread_desc =
make_ConstantTensorDescriptor(Sequence<NPerThread, KPerThread, InTileSizeH, InTileSizeW>{});
constexpr auto out_thread_desc = make_ConstantTensorDescriptor(
Sequence<NPerThread, KPerThread, OutTileSizeH, OutTileSizeW>{});
constexpr auto out_thread_global_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());
constexpr unsigned out_transform_thread_size = out_transform_thread_desc.GetElementSpace();
constexpr unsigned out_thread_size = out_thread_desc.GetElementSpace();
TFloat p_out_transform_thread[out_transform_thread_size];
TFloat p_out_thread[out_thread_size];
#if 0
if(blockIdx.x == 0 && threadIdx.x == 0)
{
printf("in_transform_block_size %u, wei_transform_block_size %u, out_transform_thread_size "
"%u, out_thread_size %u \n",
in_transform_block_size,
wei_transform_block_size,
out_transform_thread_size,
out_thread_size);
}
#endif
// set threadwise output transform tensor to 0
threadwise_4d_tensor_op_unary<TFloat, decltype(out_transform_thread_desc), decltype(f_set0)>(
out_transform_thread_desc, p_out_transform_thread, f_set0);
for(unsigned c_block_data_offset = 0; c_block_data_offset < in_global_desc.GetLength(I1);
c_block_data_offset += CPerBlock, __syncthreads())
{
// blockwise transform input
blockwise_winograd_transform_input<TFloat,
InTileSizeH,
InTileSizeW,
S,
R,
OutTileSizeH,
OutTileSizeW,
NPerBlock,
CPerBlock,
YPerBlock,
XPerBlock,
BlockSize>(
p_in_global + in_global_desc.Get1dIndex(n_block_data_offset,
c_block_data_offset,
hi_block_data_offset,
wi_block_data_offset),
p_in_transform_block);
// blockwise transform weights
blockwise_winograd_transform_weight<TFloat,
InTileSizeH,
InTileSizeW,
S,
R,
OutTileSizeH,
OutTileSizeW,
KPerBlock,
CPerBlock,
BlockSize>(
p_wei_global +
wei_global_desc.Get1dIndex(k_block_data_offset, c_block_data_offset, 0, 0),
p_wei_transform_block);
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
{
// threadwise point multiplication
threadwise_winograd_calculate_transformed_output<
TFloat,
decltype(in_transform_thread_block_desc),
decltype(wei_transform_thread_block_desc),
decltype(out_transform_thread_desc),
InTileSizeH,
InTileSizeW,
S,
R,
OutTileSizeH,
OutTileSizeW>(
in_transform_thread_block_desc,
p_in_transform_block +
in_transform_block_desc.Get1dIndex(n_thread_data_offset,
c_thread_data,
y_thread_data_offset * InTileSizeH,
x_thread_data_offset * InTileSizeW),
wei_transform_thread_block_desc,
p_wei_transform_block +
wei_transform_block_desc.Get1dIndex(k_thread_data_offset, c_thread_data, 0, 0),
out_transform_thread_desc,
p_out_transform_thread);
}
};
// transform back
threadwise_winograd_reverse_transform_output<TFloat,
decltype(out_transform_thread_desc),
decltype(out_thread_desc),
InTileSizeH,
InTileSizeW,
S,
R,
OutTileSizeH,
OutTileSizeW>(
out_transform_thread_desc, p_out_transform_thread, out_thread_desc, p_out_thread);
// copy output tensor from register to global mem
threadwise_4d_tensor_op_binary<TFloat,
decltype(out_thread_desc),
decltype(out_thread_global_desc),
decltype(f_copy)>(
out_thread_desc,
p_out_thread,
out_thread_global_desc,
p_out_global +
out_global_desc.Get1dIndex(n_block_data_offset + n_thread_data_offset,
k_block_data_offset + k_thread_data_offset,
ho_block_data_offset + y_thread_data_offset * OutTileSizeH,
wo_block_data_offset + x_thread_data_offset * OutTileSizeW),
f_copy);
}
\ No newline at end of file
......@@ -44,17 +44,11 @@ __device__ void threadwise_direct_convolution(InDesc,
const unsigned hi = ho + s;
const unsigned wi = wo + r;
const unsigned in_index =
in_desc.GetStride(I0) * n + in_desc.GetStride(I1) * c +
in_desc.GetStride(I2) * hi + in_desc.GetStride(I3) * wi;
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
const unsigned wei_index =
wei_desc.GetStride(I0) * k + wei_desc.GetStride(I1) * c +
wei_desc.GetStride(I2) * s + in_desc.GetStride(I3) * r;
const unsigned wei_index = wei_desc.Get1dIndex(k, c, s, r);
const unsigned out_index =
out_desc.GetStride(I0) * n + out_desc.GetStride(I1) * k +
out_desc.GetStride(I2) * ho + out_desc.GetStride(I3) * wo;
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
p_out[out_index] += p_wei[wei_index] * p_in[in_index];
......
#pragma once
#include "constant_tensor_descriptor.cuh"
template <class TFloat,
class InTransThreadDesc, //{NPerThread, CPerThread, InTileSizeH, InTileSizeW}
class WeiTransThreadDesc, //{KPerThread, CPerThread, InTileSizeH, InTileSizeW}
class OutTransThreadDesc, //{NPerThread, KPerThread, InTileSizeH, InTileSizeW}
unsigned InTileSizeH,
unsigned InTileSizeW,
unsigned S,
unsigned R,
unsigned OutTileSizeH,
unsigned OutTileSizeW>
__device__ void
threadwise_winograd_calculate_transformed_output(InTransThreadDesc,
TFloat* const __restrict__ p_in_transform_thread,
WeiTransThreadDesc,
TFloat* const __restrict__ p_wei_transform_thread,
OutTransThreadDesc,
TFloat* __restrict__ p_out_transform_thread)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto in_transform_thread_desc = InTransThreadDesc{};
constexpr auto wei_transform_thread_desc = WeiTransThreadDesc{};
constexpr auto out_transform_thread_desc = OutTransThreadDesc{};
for(unsigned n = 0; n < out_transform_thread_desc.GetLength(I0); ++n)
{
for(unsigned k = 0; k < out_transform_thread_desc.GetLength(I1); ++k)
{
for(unsigned h = 0; h < out_transform_thread_desc.GetLength(I2); ++h)
{
for(unsigned w = 0; w < out_transform_thread_desc.GetLength(I3); ++w)
{
for(unsigned c = 0; c < wei_transform_thread_desc.GetLength(I1); ++c)
{
const unsigned in_index = in_transform_thread_desc.Get1dIndex(n, c, h, w);
const unsigned wei_index = wei_transform_thread_desc.Get1dIndex(k, c, h, w);
const unsigned out_index = out_transform_thread_desc.Get1dIndex(n, k, h, w);
p_out_transform_thread[out_index] +=
p_wei_transform_thread[wei_index] * p_in_transform_thread[in_index];
}
}
}
}
}
}
template <class TFloat,
class OutTransThreadDesc, //{NPerThread, KPerThread, InTileSizeH, InTileSizeW}
class OutThreadDesc, //{NPerThread, CPerThread, OutTileSizeH, OutTileSizeW}
unsigned InTileSizeH,
unsigned InTileSizeW,
unsigned S,
unsigned R,
unsigned OutTileSizeH,
unsigned OutTileSizeW>
__device__ void
threadwise_winograd_reverse_transform_output(OutTransThreadDesc,
TFloat* const __restrict__ p_out_transform_thread,
OutThreadDesc,
TFloat* __restrict__ p_out_thread)
{
static_assert(InTileSizeH == 4, "wrong");
static_assert(InTileSizeW == 4, "wrong");
static_assert(S == 3, "wrong");
static_assert(R == 3, "wrong");
static_assert(OutTileSizeH == 2, "wrong");
static_assert(OutTileSizeW == 2, "wrong");
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto out_transform_thread_desc = OutTransThreadDesc{};
constexpr auto out_thread_desc = OutThreadDesc{};
static_assert(InTileSizeH == out_transform_thread_desc.GetLength(I2), "wrong");
static_assert(InTileSizeW == out_transform_thread_desc.GetLength(I3), "wrong");
static_assert(OutTileSizeH == out_thread_desc.GetLength(I2), "wrong");
static_assert(OutTileSizeW == out_thread_desc.GetLength(I3), "wrong");
for(unsigned n = 0; n < out_thread_desc.GetLength(I0); ++n)
{
for(unsigned k = 0; k < out_thread_desc.GetLength(I1); ++k)
{
p_out_thread[out_thread_desc.Get1dIndex(n, k, 0, 0)] =
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 0)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 2)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 0)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 2)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 0)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 2)];
p_out_thread[out_thread_desc.Get1dIndex(n, k, 0, 1)] =
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 0, 3)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 3)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 3)];
p_out_thread[out_thread_desc.Get1dIndex(n, k, 1, 0)] =
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 0)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 0)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 0)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 2)];
p_out_thread[out_thread_desc.Get1dIndex(n, k, 1, 1)] =
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 1)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 2)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 1, 3)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 2)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 2, 3)] -
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 1)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 2)] +
p_out_transform_thread[out_transform_thread_desc.Get1dIndex(n, k, 3, 3)];
}
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment