Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
28354a0f
Commit
28354a0f
authored
Feb 12, 2019
by
Chao Liu
Browse files
make LDS double buffer works, 1x1 conv now hits 80% of peak
parent
61ac0866
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
247 additions
and
29 deletions
+247
-29
driver/conv.cu
driver/conv.cu
+1
-1
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
+29
-28
src/include/blockwise_2d_tensor_op.cuh
src/include/blockwise_2d_tensor_op.cuh
+192
-0
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh
...t_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh
+25
-0
No files found.
driver/conv.cu
View file @
28354a0f
...
...
@@ -614,7 +614,7 @@ int main()
nrepeat
);
#endif
#if
1
#if
0
if(S == 3 && R == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
...
...
driver/device_implicit_gemm_convolution_2_cnhw_csrk_knhw.cuh
View file @
28354a0f
...
...
@@ -128,7 +128,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
constexpr
unsigned
BlockSize
=
64
;
#elif 1
// 1x1, 28x28, 128 threads
// 1x1, 28x28, 128 threads, no lds-double-buffer
// 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128
constexpr
unsigned
BPerBlock
=
64
;
constexpr
unsigned
KPerBlock
=
128
;
constexpr
unsigned
CPerBlock
=
8
;
...
...
@@ -215,37 +216,37 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
cudaEventCreate
(
&
start
);
cudaEventRecord
(
start
,
0
);
#if
1
#if
0
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
#else
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer
#endif
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_cnhw_desc
),
decltype
(
wei_csrk_desc
),
decltype
(
out_knhw_desc
),
BPerBlock
,
KPerBlock
,
CPerBlock
,
BPerThread
,
KPerThread
,
GemmThreadPerColumnPerCluster
,
GemmThreadPerRowPerCluster
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
,
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
>
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_cnhw_desc
),
decltype
(
wei_csrk_desc
),
decltype
(
out_knhw_desc
),
BPerBlock
,
KPerBlock
,
CPerBlock
,
BPerThread
,
KPerThread
,
GemmThreadPerColumnPerCluster
,
GemmThreadPerRowPerCluster
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
InBlockCopyThreadPerDim0
,
InBlockCopyThreadPerDim1
,
WeiBlockCopyThreadPerDim0
,
WeiBlockCopyThreadPerDim1
,
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
>
<<<
grid_dim
,
block_dim
>>>
(
in_cnhw_desc
,
static_cast
<
T
*>
(
in_cnhw_device_buf
.
GetDeviceBuffer
()),
wei_csrk_desc
,
...
...
src/include/blockwise_2d_tensor_op.cuh
View file @
28354a0f
...
...
@@ -512,4 +512,196 @@ struct Blockwise2dTensorCopy3
}
}
}
#if 1
__device__
constexpr
unsigned
GetRegisterClipboardSize
()
const
{
static_assert
(
is_same
<
Float
,
float
>::
value
,
"wrong! only support float!
\n
"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
unsigned
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
unsigned
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
unsigned
thread_per_d1
=
(
L1
+
DataPerRead
-
1
)
/
DataPerRead
;
constexpr
unsigned
thread_per_d0
=
BlockSize
/
thread_per_d1
;
return
DataPerRead
*
(
L0
+
thread_per_d0
-
1
)
/
thread_per_d0
;
}
__device__
void
RunLoadRegisterClipboard
(
const
Float
*
__restrict__
p_src
,
Float
*
p_clipboard
)
const
{
static_assert
(
is_same
<
Float
,
float
>::
value
,
"wrong! only support float!
\n
"
);
using
Float2
=
float2
;
using
Float4
=
float4
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
unsigned
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
unsigned
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
unsigned
thread_per_d1
=
(
L1
+
DataPerRead
-
1
)
/
DataPerRead
;
constexpr
unsigned
thread_per_d0
=
BlockSize
/
thread_per_d1
;
constexpr
unsigned
num_active_thread
=
thread_per_d0
*
thread_per_d1
;
if
(
BlockSize
>
num_active_thread
)
{
if
(
get_thread_local_1d_id
()
>=
num_active_thread
)
{
return
;
}
}
constexpr
unsigned
nloop_d0
=
L0
/
thread_per_d0
;
constexpr
unsigned
src_loop_stride
=
SrcDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
constexpr
unsigned
dst_loop_stride
=
DstDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
for
(
unsigned
iloop
=
0
;
iloop
<
nloop_d0
;
++
iloop
)
{
if
(
DataPerRead
==
1
)
{
p_clipboard
[
iloop
]
=
p_src
[
mSrcMyThreadOffset
+
iloop
*
src_loop_stride
];
}
else
if
(
DataPerRead
==
2
)
{
*
(
reinterpret_cast
<
Float2
*>
(
p_clipboard
+
iloop
*
2
))
=
*
(
reinterpret_cast
<
const
Float2
*>
(
p_src
+
mSrcMyThreadOffset
+
iloop
*
src_loop_stride
));
}
else
if
(
DataPerRead
==
4
)
{
*
(
reinterpret_cast
<
Float4
*>
(
p_clipboard
+
iloop
*
4
))
=
*
(
reinterpret_cast
<
const
Float4
*>
(
p_src
+
mSrcMyThreadOffset
+
iloop
*
src_loop_stride
));
}
else
{
assert
(
false
);
}
}
constexpr
bool
has_tail_d0
=
(
L0
>
nloop_d0
*
thread_per_d0
);
if
(
has_tail_d0
)
{
constexpr
unsigned
tail_d0
=
L0
-
nloop_d0
*
thread_per_d0
;
if
(
get_thread_local_1d_id
()
<
tail_d0
*
thread_per_d1
)
{
if
(
DataPerRead
==
1
)
{
p_clipboard
[
nloop_d0
]
=
p_src
[
mSrcMyThreadOffset
+
nloop_d0
*
src_loop_stride
];
}
else
if
(
DataPerRead
==
2
)
{
*
(
reinterpret_cast
<
Float2
*>
(
p_clipboard
+
nloop_d0
*
2
))
=
*
(
reinterpret_cast
<
const
Float2
*>
(
p_src
+
mSrcMyThreadOffset
+
nloop_d0
*
src_loop_stride
));
}
else
if
(
DataPerRead
==
4
)
{
*
(
reinterpret_cast
<
Float4
*>
(
p_clipboard
+
nloop_d0
*
4
))
=
*
(
reinterpret_cast
<
const
Float4
*>
(
p_src
+
mSrcMyThreadOffset
+
nloop_d0
*
src_loop_stride
));
}
else
{
assert
(
false
);
}
}
}
}
__device__
void
RunStoreRegisterClipboard
(
const
Float
*
__restrict__
p_clipboard
,
Float
*
__restrict__
p_dst
)
const
{
static_assert
(
is_same
<
Float
,
float
>::
value
,
"wrong! only support float!
\n
"
);
using
Float2
=
float2
;
using
Float4
=
float4
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
unsigned
L0
=
CopyLengths
{}.
Get
(
I0
);
constexpr
unsigned
L1
=
CopyLengths
{}.
Get
(
I1
);
constexpr
unsigned
thread_per_d1
=
(
L1
+
DataPerRead
-
1
)
/
DataPerRead
;
constexpr
unsigned
thread_per_d0
=
BlockSize
/
thread_per_d1
;
constexpr
unsigned
num_active_thread
=
thread_per_d0
*
thread_per_d1
;
if
(
BlockSize
>
num_active_thread
)
{
if
(
get_thread_local_1d_id
()
>=
num_active_thread
)
{
return
;
}
}
constexpr
unsigned
nloop_d0
=
L0
/
thread_per_d0
;
constexpr
unsigned
src_loop_stride
=
SrcDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
constexpr
unsigned
dst_loop_stride
=
DstDesc
{}.
GetStride
(
I0
)
*
thread_per_d0
;
for
(
unsigned
iloop
=
0
;
iloop
<
nloop_d0
;
++
iloop
)
{
if
(
DataPerRead
==
1
)
{
p_dst
[
mDstMyThreadOffset
+
iloop
*
dst_loop_stride
]
=
p_clipboard
[
iloop
];
}
else
if
(
DataPerRead
==
2
)
{
*
(
reinterpret_cast
<
Float2
*>
(
p_dst
+
mDstMyThreadOffset
+
iloop
*
dst_loop_stride
))
=
*
(
reinterpret_cast
<
const
Float2
*>
(
p_clipboard
+
iloop
*
2
));
}
else
if
(
DataPerRead
==
4
)
{
*
(
reinterpret_cast
<
Float4
*>
(
p_dst
+
mDstMyThreadOffset
+
iloop
*
dst_loop_stride
))
=
*
(
reinterpret_cast
<
const
Float4
*>
(
p_clipboard
+
iloop
*
4
));
}
else
{
assert
(
false
);
}
}
constexpr
bool
has_tail_d0
=
(
L0
>
nloop_d0
*
thread_per_d0
);
if
(
has_tail_d0
)
{
constexpr
unsigned
tail_d0
=
L0
-
nloop_d0
*
thread_per_d0
;
if
(
get_thread_local_1d_id
()
<
tail_d0
*
thread_per_d1
)
{
if
(
DataPerRead
==
1
)
{
p_dst
[
mDstMyThreadOffset
+
nloop_d0
*
dst_loop_stride
]
=
p_clipboard
[
nloop_d0
];
}
else
if
(
DataPerRead
==
2
)
{
*
(
reinterpret_cast
<
Float2
*>
(
p_dst
+
mDstMyThreadOffset
+
nloop_d0
*
dst_loop_stride
))
=
*
(
reinterpret_cast
<
const
Float2
*>
(
p_clipboard
+
nloop_d0
*
2
));
}
else
if
(
DataPerRead
==
4
)
{
*
(
reinterpret_cast
<
Float4
*>
(
p_dst
+
mDstMyThreadOffset
+
nloop_d0
*
dst_loop_stride
))
=
*
(
reinterpret_cast
<
const
Float4
*>
(
p_clipboard
+
nloop_d0
*
4
));
}
else
{
assert
(
false
);
}
}
}
}
#endif
};
src/include/gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_buffer.cuh
View file @
28354a0f
...
...
@@ -262,8 +262,26 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
__syncthreads
();
// load next data
#if 0
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
#elif
0
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block_next
);
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
#elif 1
Float
p_in_register_clipboard
[
blockwise_in_copy
.
GetRegisterClipboardSize
()];
Float
p_wei_register_clipboard
[
blockwise_wei_copy
.
GetRegisterClipboardSize
()];
blockwise_in_copy
.
RunLoadRegisterClipboard
(
p_in_global_block_offset
,
p_in_register_clipboard
);
blockwise_wei_copy
.
RunLoadRegisterClipboard
(
p_wei_global_block_offset
,
p_wei_register_clipboard
);
#endif
// compute on current data
// a series of GEMM
...
...
@@ -283,6 +301,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_double_b
f_accum
);
}
}
#if 0
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block_next);
#elif
1
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
p_in_block_next
);
blockwise_wei_copy
.
RunStoreRegisterClipboard
(
p_wei_register_clipboard
,
p_wei_block_next
);
#endif
}
// last computation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment