Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
08c7f743
Commit
08c7f743
authored
Nov 07, 2018
by
Chao Liu
Browse files
add 2nd version of blockwise_tensor_op
parent
5d2cafcb
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
309 additions
and
205 deletions
+309
-205
driver/conv.cu
driver/conv.cu
+29
-14
src/include/blockwise_tensor_op.cuh
src/include/blockwise_tensor_op.cuh
+155
-0
src/include/common.cuh
src/include/common.cuh
+13
-0
src/include/constant_tensor_descriptor.cuh
src/include/constant_tensor_descriptor.cuh
+1
-1
src/include/direct_convolution.cuh
src/include/direct_convolution.cuh
+50
-190
src/include/threadwise_tensor_op.cuh
src/include/threadwise_tensor_op.cuh
+61
-0
No files found.
driver/conv.cu
View file @
08c7f743
...
@@ -137,23 +137,31 @@ void device_convolution(
...
@@ -137,23 +137,31 @@ void device_convolution(
constexpr
unsigned
CPerBlockLoop
=
1
;
constexpr
unsigned
CPerBlockLoop
=
1
;
constexpr
unsigned
OutTileSizeH
=
2
;
constexpr
unsigned
OutTileSizeH
=
2
;
constexpr
unsigned
OutTileSizeW
=
2
;
constexpr
unsigned
OutTileSizeW
=
2
;
constexpr
unsigned
YPerBlock
=
4
;
constexpr
unsigned
YPerBlock
=
8
;
constexpr
unsigned
XPerBlock
=
8
;
constexpr
unsigned
XPerBlock
=
16
;
constexpr
unsigned
NBlockCopyLen0
=
1
;
constexpr
unsigned
NBlockCopyLen0
=
1
;
constexpr
unsigned
NBlockCopyLen1
=
1
;
constexpr
unsigned
NBlockCopyLen1
=
1
;
constexpr
unsigned
NBlockCopyLen2
=
2
;
constexpr
unsigned
NBlockCopyLen2
=
2
;
constexpr
unsigned
NBlockCopyLen3
=
16
;
constexpr
unsigned
NBlockCopyLen3
=
16
;
constexpr
unsigned
nblock
=
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
constexpr
unsigned
BlockSize
=
128
;
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I2
)
/
(
OutTileSizeH
*
YPerBlock
))
*
(
out_desc
.
GetLength
(
I3
)
/
(
OutTileSizeW
*
XPerBlock
));
dim3
block_dim
(
32
);
constexpr
unsigned
GridSize
=
(
out_desc
.
GetLength
(
I0
)
/
NPerBlock
)
*
dim3
grid_dim
(
nblock
);
(
out_desc
.
GetLength
(
I1
)
/
KPerBlock
)
*
(
out_desc
.
GetLength
(
I2
)
/
(
OutTileSizeH
*
YPerBlock
))
*
(
out_desc
.
GetLength
(
I3
)
/
(
OutTileSizeW
*
XPerBlock
));
printf
(
"__func__: nblock %u
\n
"
,
nblock
);
dim3
block_dim
(
BlockSize
);
dim3
grid_dim
(
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
cudaEvent_t
start
,
stop
;
float
elapsedTime
;
cudaEventCreate
(
&
start
);
cudaEventRecord
(
start
,
0
);
gridwise_convolution
<
T
,
gridwise_convolution
<
T
,
InDesc
,
InDesc
,
...
@@ -169,7 +177,9 @@ void device_convolution(
...
@@ -169,7 +177,9 @@ void device_convolution(
NBlockCopyLen0
,
NBlockCopyLen0
,
NBlockCopyLen1
,
NBlockCopyLen1
,
NBlockCopyLen2
,
NBlockCopyLen2
,
NBlockCopyLen3
>
NBlockCopyLen3
,
BlockSize
,
GridSize
>
<<<
grid_dim
,
block_dim
>>>
(
InDesc
{},
<<<
grid_dim
,
block_dim
>>>
(
InDesc
{},
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_device_buf
.
GetDeviceBuffer
()),
WeiDesc
{},
WeiDesc
{},
...
@@ -177,6 +187,13 @@ void device_convolution(
...
@@ -177,6 +187,13 @@ void device_convolution(
OutDesc
{},
OutDesc
{},
static_cast
<
T
*>
(
out_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_device_buf
.
GetDeviceBuffer
()));
cudaEventCreate
(
&
stop
);
cudaEventRecord
(
stop
,
0
);
cudaEventSynchronize
(
stop
);
cudaEventElapsedTime
(
&
elapsedTime
,
start
,
stop
);
printf
(
"Elapsed time : %f ms
\n
"
,
elapsedTime
);
checkCudaErrors
(
cudaGetLastError
());
checkCudaErrors
(
cudaGetLastError
());
out_device_buf
.
FromDevice
(
out
.
mData
.
data
());
out_device_buf
.
FromDevice
(
out
.
mData
.
data
());
}
}
...
@@ -231,7 +248,7 @@ int main()
...
@@ -231,7 +248,7 @@ int main()
int
num_thread
=
std
::
thread
::
hardware_concurrency
();
int
num_thread
=
std
::
thread
::
hardware_concurrency
();
#if
1
#if
0
in.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
in.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
out_host.GenerateTensorValue(GeneratorConstant<float>{0}, num_thread);
out_host.GenerateTensorValue(GeneratorConstant<float>{0}, num_thread);
...
@@ -241,9 +258,7 @@ int main()
...
@@ -241,9 +258,7 @@ int main()
device_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
device_convolution
(
in_desc
,
in
,
wei_desc
,
wei
,
out_desc
,
out_device
);
std
::
cout
<<
__func__
<<
": done"
<<
std
::
endl
;
#if 0
#if 1
host_convolution(in, wei, out_host);
host_convolution(in, wei, out_host);
float error = 0;
float error = 0;
...
...
src/include/blockwise_tensor_op.cuh
0 → 100644
View file @
08c7f743
#pragma once
#include "constant_tensor_descriptor.cuh"
#if 0
template <class TFloat,
class SrcDesc,
class DstDesc,
unsigned NWorkLen0,
unsigned NWorkLen1,
unsigned NWorkLen2,
unsigned NWorkLen3,
class F,
unsigned BlockSize>
__device__ void blockwise_4d_tensor_op(
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
}
#endif
constexpr unsigned NWorkStride3 = 1;
constexpr unsigned NWorkStride2 = NWorkLen3 * NWorkStride3;
constexpr unsigned NWorkStride1 = NWorkLen2 * NWorkStride2;
constexpr unsigned NWorkStride0 = NWorkLen1 * NWorkStride1;
unsigned itmp =
threadIdx.x;
const unsigned did0_begin = itmp / NWorkStride0;
itmp -= did0_begin * NWorkStride0;
const unsigned did1_begin = itmp / NWorkStride1;
itmp -= did1_begin * NWorkStride1;
const unsigned did2_begin = itmp / NWorkStride2;
itmp -= did2_begin * NWorkStride2;
const unsigned did3_begin = itmp / NWorkStride3;
for(unsigned did0 = did0_begin; did0 < src_desc.GetLength(I0); did0 += NWorkLen0)
{
for(unsigned did1 = did1_begin; did1 < src_desc.GetLength(I1); did1 += NWorkLen1)
{
for(unsigned did2 = did2_begin; did2 < src_desc.GetLength(I2); did2 += NWorkLen2)
{
for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
{
const unsigned sindex =
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned dindex =
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[dindex], p_dst[sindex]);
#if 0
// if(threadIdx.x == 0)
{
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}
}
}
#elif
1
template
<
class
TFloat
,
class
SrcDesc
,
class
DstDesc
,
unsigned
NWorkLen0
,
unsigned
NWorkLen1
,
unsigned
NWorkLen2
,
unsigned
NWorkLen3
,
class
F
,
unsigned
BlockSize
>
__device__
void
blockwise_4d_tensor_op
(
SrcDesc
,
TFloat
*
const
__restrict__
p_src
,
DstDesc
,
TFloat
*
__restrict__
p_dst
,
F
f
)
{
constexpr
auto
I0
=
Index
<
0
>
{};
constexpr
auto
I1
=
Index
<
1
>
{};
constexpr
auto
I2
=
Index
<
2
>
{};
constexpr
auto
I3
=
Index
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
static_assert
(
is_same
<
decltype
(
src_desc
.
GetLengths
()),
decltype
(
dst_desc
.
GetLengths
())
>::
value
);
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
}
#endif
unsigned
lid
=
threadIdx
.
x
;
for
(
unsigned
i
=
lid
;
i
<
src_desc
.
GetElementSize
();
i
+=
BlockSize
)
{
unsigned
is
=
i
;
const
unsigned
did0
=
is
/
src_desc
.
GetStride
(
I0
);
is
-=
did0
*
src_desc
.
GetStride
(
I0
);
const
unsigned
did1
=
is
/
src_desc
.
GetStride
(
I1
);
is
-=
did1
*
src_desc
.
GetStride
(
I1
);
const
unsigned
did2
=
is
/
src_desc
.
GetStride
(
I2
);
is
-=
did2
*
src_desc
.
GetStride
(
I2
);
const
unsigned
did3
=
is
/
src_desc
.
GetStride
(
I3
);
const
unsigned
sindex
=
src_desc
.
GetStride
(
I0
)
*
did0
+
src_desc
.
GetStride
(
I1
)
*
did1
+
src_desc
.
GetStride
(
I2
)
*
did2
+
src_desc
.
GetStride
(
I3
)
*
did3
;
const
unsigned
dindex
=
dst_desc
.
GetStride
(
I0
)
*
did0
+
dst_desc
.
GetStride
(
I1
)
*
did1
+
dst_desc
.
GetStride
(
I2
)
*
did2
+
dst_desc
.
GetStride
(
I3
)
*
did3
;
f
(
p_src
[
sindex
],
p_dst
[
dindex
]);
}
}
#endif
src/include/common.cuh
0 → 100644
View file @
08c7f743
#pragma once
template
<
class
T1
,
class
T2
>
struct
is_same
{
static
const
bool
value
=
false
;
};
template
<
class
T
>
struct
is_same
<
T
,
T
>
{
static
const
bool
value
=
true
;
};
src/include/constant_tensor_descriptor.cuh
View file @
08c7f743
#pragma once
#pragma once
#include "
helper_cuda.
h"
#include "
common.cu
h"
template
<
class
T
,
T
N
>
template
<
class
T
,
T
N
>
struct
Constant
struct
Constant
...
...
src/include/direct_convolution.cuh
View file @
08c7f743
#pragma once
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "constant_tensor_descriptor.cuh"
#include "blockwise_tensor_op.cuh"
template
<
class
TFloat
,
#include "threadwise_tensor_op.cuh"
class
SrcDesc
,
class
DstDesc
,
unsigned
NWorkLen0
,
unsigned
NWorkLen1
,
unsigned
NWorkLen2
,
unsigned
NWorkLen3
,
class
F
>
__device__
void
blockwise_4d_tensor_op
(
SrcDesc
,
TFloat
*
const
__restrict__
p_src
,
DstDesc
,
TFloat
*
__restrict__
p_dst
,
F
f
)
{
constexpr
auto
I0
=
Index
<
0
>
{};
constexpr
auto
I1
=
Index
<
1
>
{};
constexpr
auto
I2
=
Index
<
2
>
{};
constexpr
auto
I3
=
Index
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
}
#endif
constexpr
unsigned
NWorkStride3
=
1
;
constexpr
unsigned
NWorkStride2
=
NWorkLen3
*
NWorkStride3
;
constexpr
unsigned
NWorkStride1
=
NWorkLen2
*
NWorkStride2
;
constexpr
unsigned
NWorkStride0
=
NWorkLen1
*
NWorkStride1
;
unsigned
itmp
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
z
*
(
blockDim
.
y
*
blockDim
.
x
);
const
unsigned
did0_begin
=
itmp
/
NWorkStride0
;
itmp
-=
did0_begin
*
NWorkStride0
;
const
unsigned
did1_begin
=
itmp
/
NWorkStride1
;
itmp
-=
did1_begin
*
NWorkStride1
;
const
unsigned
did2_begin
=
itmp
/
NWorkStride2
;
itmp
-=
did2_begin
*
NWorkStride2
;
const
unsigned
did3_begin
=
itmp
/
NWorkStride3
;
for
(
unsigned
did0
=
did0_begin
;
did0
<
src_desc
.
GetLength
(
I0
);
did0
+=
NWorkLen0
)
{
for
(
unsigned
did1
=
did1_begin
;
did1
<
src_desc
.
GetLength
(
I1
);
did1
+=
NWorkLen1
)
{
for
(
unsigned
did2
=
did2_begin
;
did2
<
src_desc
.
GetLength
(
I2
);
did2
+=
NWorkLen2
)
{
for
(
unsigned
did3
=
did3_begin
;
did3
<
src_desc
.
GetLength
(
I3
);
did3
+=
NWorkLen3
)
{
const
unsigned
sindex
=
src_desc
.
GetStride
(
I0
)
*
did0
+
src_desc
.
GetStride
(
I1
)
*
did1
+
src_desc
.
GetStride
(
I2
)
*
did2
+
src_desc
.
GetStride
(
I3
)
*
did3
;
const
unsigned
dindex
=
dst_desc
.
GetStride
(
I0
)
*
did0
+
dst_desc
.
GetStride
(
I1
)
*
did1
+
dst_desc
.
GetStride
(
I2
)
*
did2
+
dst_desc
.
GetStride
(
I3
)
*
did3
;
f
(
p_src
[
dindex
],
p_dst
[
sindex
]);
#if 0
// if(threadIdx.x == 0)
{
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}
}
}
template
<
class
TFloat
,
class
SrcDesc
,
class
DstDesc
,
class
F
>
__device__
void
threadwise_4d_tensor_op
(
SrcDesc
,
TFloat
*
const
__restrict__
p_src
,
DstDesc
,
TFloat
*
__restrict__
p_dst
,
F
f
)
{
constexpr
auto
I0
=
Index
<
0
>
{};
constexpr
auto
I1
=
Index
<
1
>
{};
constexpr
auto
I2
=
Index
<
2
>
{};
constexpr
auto
I3
=
Index
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc);
print_ConstantTensorDescriptor(dst_desc);
}
#endif
for
(
unsigned
did0
=
0
;
did0
<
src_desc
.
GetLength
(
I0
);
++
did0
)
{
for
(
unsigned
did1
=
0
;
did1
<
src_desc
.
GetLength
(
I1
);
++
did1
)
{
for
(
unsigned
did2
=
0
;
did2
<
src_desc
.
GetLength
(
I2
);
++
did2
)
{
for
(
unsigned
did3
=
0
;
did3
<
src_desc
.
GetLength
(
I3
);
++
did3
)
{
const
unsigned
sindex
=
src_desc
.
GetStride
(
I0
)
*
did0
+
src_desc
.
GetStride
(
I1
)
*
did1
+
src_desc
.
GetStride
(
I2
)
*
did2
+
src_desc
.
GetStride
(
I3
)
*
did3
;
const
unsigned
dindex
=
dst_desc
.
GetStride
(
I0
)
*
did0
+
dst_desc
.
GetStride
(
I1
)
*
did1
+
dst_desc
.
GetStride
(
I2
)
*
did2
+
dst_desc
.
GetStride
(
I3
)
*
did3
;
f
(
p_src
[
sindex
],
p_dst
[
dindex
]);
#if 0
if(threadIdx.x == 0)
{
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}
}
}
template
<
class
TFloat
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
TFloat
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
__device__
void
threadwise_direct_convolution
(
InDesc
,
__device__
void
threadwise_direct_convolution
(
InDesc
,
...
@@ -232,7 +91,8 @@ template <class TFloat,
...
@@ -232,7 +91,8 @@ template <class TFloat,
class
WeiDesc
,
class
WeiDesc
,
class
OutDesc
,
class
OutDesc
,
unsigned
OutTileSizeH
,
unsigned
OutTileSizeH
,
unsigned
OutTileSizeW
>
unsigned
OutTileSizeW
,
unsigned
BlockSize
>
__device__
void
blockwise_convolution
(
InDesc
,
__device__
void
blockwise_convolution
(
InDesc
,
TFloat
*
const
__restrict__
p_in
,
TFloat
*
const
__restrict__
p_in
,
WeiDesc
,
WeiDesc
,
...
@@ -290,14 +150,11 @@ __device__ void blockwise_convolution(InDesc,
...
@@ -290,14 +150,11 @@ __device__ void blockwise_convolution(InDesc,
constexpr
auto
out_thread_dst_desc
=
constexpr
auto
out_thread_dst_desc
=
make_ConstantTensorDescriptor
(
out_thread_src_desc
.
GetLengths
());
make_ConstantTensorDescriptor
(
out_thread_src_desc
.
GetLengths
());
const
unsigned
thread_sz
=
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
;
const
unsigned
thread_id
=
threadIdx
.
x
;
const
unsigned
thread_id
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
z
*
(
blockDim
.
y
*
blockDim
.
x
);
for
(
unsigned
thread_work_id
=
thread_id
;
for
(
unsigned
thread_work_id
=
thread_id
;
thread_work_id
<
NPerBlock
*
KPerBlock
*
YPerBlock
*
XPerBlock
;
thread_work_id
<
NPerBlock
*
KPerBlock
*
YPerBlock
*
XPerBlock
;
thread_work_id
+=
thread_sz
)
thread_work_id
+=
BlockSize
)
{
{
unsigned
itmp
=
thread_work_id
;
unsigned
itmp
=
thread_work_id
;
unsigned
n_thread_work_id
=
itmp
/
(
KPerBlock
*
YPerBlock
*
XPerBlock
);
unsigned
n_thread_work_id
=
itmp
/
(
KPerBlock
*
YPerBlock
*
XPerBlock
);
...
@@ -397,7 +254,9 @@ template <class TFloat,
...
@@ -397,7 +254,9 @@ template <class TFloat,
unsigned
NBlockCopyLen0
,
unsigned
NBlockCopyLen0
,
unsigned
NBlockCopyLen1
,
unsigned
NBlockCopyLen1
,
unsigned
NBlockCopyLen2
,
unsigned
NBlockCopyLen2
,
unsigned
NBlockCopyLen3
>
unsigned
NBlockCopyLen3
,
unsigned
BlockSize
,
unsigned
GridSize
>
__global__
void
gridwise_convolution
(
InDesc
,
__global__
void
gridwise_convolution
(
InDesc
,
TFloat
*
const
__restrict__
p_in
,
TFloat
*
const
__restrict__
p_in
,
WeiDesc
,
WeiDesc
,
...
@@ -452,8 +311,7 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -452,8 +311,7 @@ __global__ void gridwise_convolution(InDesc,
__shared__
TFloat
p_wei_block
[
wei_block_size
];
__shared__
TFloat
p_wei_block
[
wei_block_size
];
__shared__
TFloat
p_out_block
[
out_block_size
];
__shared__
TFloat
p_out_block
[
out_block_size
];
const
unsigned
block_id
=
const
unsigned
block_id
=
blockIdx
.
x
;
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
+
blockIdx
.
z
*
(
gridDim
.
y
*
gridDim
.
x
);
unsigned
itmp
=
block_id
;
unsigned
itmp
=
block_id
;
unsigned
n_block_work_id
=
itmp
/
(
KBlockWork
*
YBlockWork
*
XBlockWork
);
unsigned
n_block_work_id
=
itmp
/
(
KBlockWork
*
YBlockWork
*
XBlockWork
);
...
@@ -515,17 +373,16 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -515,17 +373,16 @@ __global__ void gridwise_convolution(InDesc,
NBlockCopyLen1
,
NBlockCopyLen1
,
NBlockCopyLen2
,
NBlockCopyLen2
,
NBlockCopyLen3
,
NBlockCopyLen3
,
decltype
(
f_copy
)
>
(
decltype
(
f_copy
)
,
in_block_glb_desc
,
BlockSize
>
(
in_block_glb_desc
,
p_in
+
in_block_glb_desc
.
Get1dIndex
(
n_block_work_begin
,
p_in
+
in_block_glb_desc
.
Get1dIndex
(
n_block_work_begin
,
c_block_work_begin
,
c_block_work_begin
,
hi_block_work_begin
,
hi_block_work_begin
,
wi_block_work_begin
),
wi_block_work_begin
),
in_block_lds_desc
,
in_block_lds_desc
,
p_in_block
,
p_in_block
,
f_copy
);
f_copy
);
#if 1
// copy weight tensor to LDS
// copy weight tensor to LDS
blockwise_4d_tensor_op
<
TFloat
,
blockwise_4d_tensor_op
<
TFloat
,
decltype
(
wei_block_glb_desc
),
decltype
(
wei_block_glb_desc
),
...
@@ -534,7 +391,8 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -534,7 +391,8 @@ __global__ void gridwise_convolution(InDesc,
NBlockCopyLen1
,
NBlockCopyLen1
,
NBlockCopyLen2
,
NBlockCopyLen2
,
NBlockCopyLen3
,
NBlockCopyLen3
,
decltype
(
f_copy
)
>
(
decltype
(
f_copy
),
BlockSize
>
(
wei_block_glb_desc
,
wei_block_glb_desc
,
p_wei
+
wei_block_glb_desc
.
Get1dIndex
(
k_block_work_begin
,
c_block_work_begin
,
0
,
0
),
p_wei
+
wei_block_glb_desc
.
Get1dIndex
(
k_block_work_begin
,
c_block_work_begin
,
0
,
0
),
wei_block_lds_desc
,
wei_block_lds_desc
,
...
@@ -549,17 +407,18 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -549,17 +407,18 @@ __global__ void gridwise_convolution(InDesc,
NBlockCopyLen1
,
NBlockCopyLen1
,
NBlockCopyLen2
,
NBlockCopyLen2
,
NBlockCopyLen3
,
NBlockCopyLen3
,
decltype
(
f_copy
)
>
(
decltype
(
f_copy
),
out_block_glb_desc
,
BlockSize
>
(
out_block_glb_desc
,
p_out
+
out_block_glb_desc
.
Get1dIndex
(
n_block_work_begin
,
p_out
+
k_block_work_begin
,
out_block_glb_desc
.
Get1dIndex
(
n_block_work_begin
,
ho_block_work_begin
,
k_block_work_begin
,
wo_block_work_begin
),
ho_block_work_begin
,
out_block_lds_desc
,
wo_block_work_begin
),
p_out_block
,
out_block_lds_desc
,
f_copy
);
p_out_block
,
f_copy
);
#if
0
#if
1
__syncthreads
();
__syncthreads
();
#endif
#endif
...
@@ -569,14 +428,15 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -569,14 +428,15 @@ __global__ void gridwise_convolution(InDesc,
decltype
(
wei_block_lds_desc
),
decltype
(
wei_block_lds_desc
),
decltype
(
out_block_lds_desc
),
decltype
(
out_block_lds_desc
),
OutTileSizeH
,
OutTileSizeH
,
OutTileSizeW
>
(
in_block_lds_desc
,
OutTileSizeW
,
p_in_block
,
BlockSize
>
(
in_block_lds_desc
,
wei_block_lds_desc
,
p_in_block
,
p_wei_block
,
wei_block_lds_desc
,
out_block_lds_desc
,
p_wei_block
,
p_out_block
);
out_block_lds_desc
,
p_out_block
);
#if
0
#if
1
__syncthreads
();
__syncthreads
();
#endif
#endif
...
@@ -588,15 +448,15 @@ __global__ void gridwise_convolution(InDesc,
...
@@ -588,15 +448,15 @@ __global__ void gridwise_convolution(InDesc,
NBlockCopyLen1
,
NBlockCopyLen1
,
NBlockCopyLen2
,
NBlockCopyLen2
,
NBlockCopyLen3
,
NBlockCopyLen3
,
decltype
(
f_copy
)
>
(
decltype
(
f_copy
)
,
out_block_lds_desc
,
BlockSize
>
(
out_block_lds_desc
,
p_out_block
,
p_out_block
,
out_block_glb_desc
,
out_block_glb_desc
,
p_out
+
out_block_glb_desc
.
Get1dIndex
(
n_block_work_begin
,
p_out
+
k
_block_work_begin
,
out_block_glb_desc
.
Get1dIndex
(
n
_block_work_begin
,
ho
_block_work_begin
,
k
_block_work_begin
,
w
o_block_work_begin
)
,
h
o_block_work_begin
,
f_copy
);
wo_block_work_begin
),
#endif
f_copy
);
}
}
}
}
src/include/threadwise_tensor_op.cuh
0 → 100644
View file @
08c7f743
#pragma once
#include "constant_tensor_descriptor.cuh"
template
<
class
TFloat
,
class
SrcDesc
,
class
DstDesc
,
class
F
>
__device__
void
threadwise_4d_tensor_op
(
SrcDesc
,
TFloat
*
const
__restrict__
p_src
,
DstDesc
,
TFloat
*
__restrict__
p_dst
,
F
f
)
{
constexpr
auto
I0
=
Index
<
0
>
{};
constexpr
auto
I1
=
Index
<
1
>
{};
constexpr
auto
I2
=
Index
<
2
>
{};
constexpr
auto
I3
=
Index
<
3
>
{};
constexpr
auto
src_desc
=
SrcDesc
{};
constexpr
auto
dst_desc
=
DstDesc
{};
static_assert
(
is_same
<
decltype
(
src_desc
.
GetLengths
()),
decltype
(
dst_desc
.
GetLengths
())
>::
value
);
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc);
print_ConstantTensorDescriptor(dst_desc);
}
#endif
for
(
unsigned
did0
=
0
;
did0
<
src_desc
.
GetLength
(
I0
);
++
did0
)
{
for
(
unsigned
did1
=
0
;
did1
<
src_desc
.
GetLength
(
I1
);
++
did1
)
{
for
(
unsigned
did2
=
0
;
did2
<
src_desc
.
GetLength
(
I2
);
++
did2
)
{
for
(
unsigned
did3
=
0
;
did3
<
src_desc
.
GetLength
(
I3
);
++
did3
)
{
const
unsigned
sindex
=
src_desc
.
GetStride
(
I0
)
*
did0
+
src_desc
.
GetStride
(
I1
)
*
did1
+
src_desc
.
GetStride
(
I2
)
*
did2
+
src_desc
.
GetStride
(
I3
)
*
did3
;
const
unsigned
dindex
=
dst_desc
.
GetStride
(
I0
)
*
did0
+
dst_desc
.
GetStride
(
I1
)
*
did1
+
dst_desc
.
GetStride
(
I2
)
*
did2
+
dst_desc
.
GetStride
(
I3
)
*
did3
;
f
(
p_src
[
sindex
],
p_dst
[
dindex
]);
#if 0
if(threadIdx.x == 0)
{
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment