Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
249c5d6d
Commit
249c5d6d
authored
Jan 15, 2020
by
Chao Liu
Browse files
nvidia build
parent
ea8aa63f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
88 additions
and
103 deletions
+88
-103
driver/include/device.hpp
driver/include/device.hpp
+4
-1
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+18
-16
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
+18
-15
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+18
-15
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
+18
-45
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+9
-8
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+3
-3
No files found.
driver/include/device.hpp
View file @
249c5d6d
...
...
@@ -76,7 +76,10 @@ void launch_kernel(F kernel,
cudaStream_t
stream_id
,
Args
...
args
)
{
cudaLaunchKernel
(
f
,
grid_dim
,
block_dim
,
p_args
,
lds_byte
,
stream_id
);
const
void
*
f
=
reinterpret_cast
<
const
void
*>
(
kernel
);
void
*
p_args
[]
=
{
&
args
...};
cudaError_t
error
=
cudaLaunchKernel
(
f
,
grid_dim
,
block_dim
,
p_args
,
lds_byte
,
stream_id
);
}
template
<
typename
...
Args
,
typename
F
>
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
249c5d6d
...
...
@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
...
...
@@ -121,22 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
gridwise_conv
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
...
...
@@ -147,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp
View file @
249c5d6d
...
...
@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
...
...
@@ -129,21 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
gridwise_conv
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
...
...
@@ -154,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
249c5d6d
...
...
@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
...
...
@@ -217,21 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
gridwise_conv
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
...
...
@@ -242,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp
View file @
249c5d6d
...
...
@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
...
...
@@ -84,36 +88,6 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmThreadGemmDataPerReadM
=
4
;
constexpr
index_t
GemmThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmM
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
...
...
@@ -186,21 +160,18 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
float
time
=
launch_and_time_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
gridwise_conv
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
...
...
@@ -211,3 +182,5 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
249c5d6d
...
...
@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
namespace
launcher
{
using
namespace
ck
;
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
...
...
@@ -25,8 +29,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
InRightPads
,
std
::
size_t
nrepeat
)
{
using
namespace
ck
;
constexpr
index_t
N
=
out_nkhw_desc
.
GetLengths
()[
0
];
constexpr
index_t
K
=
out_nkhw_desc
.
GetLengths
()[
1
];
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLengths
()[
1
];
...
...
@@ -207,12 +209,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
0
,
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
});
});
...
...
@@ -229,3 +228,5 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
}
// namespace launcher
driver/src/conv_bwd_data_driver.cpp
View file @
249c5d6d
...
...
@@ -21,12 +21,12 @@
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
using
namespace
launcher
;
#if 1
// 3x3 filter, 2x2 stride, 35x35 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1
024
;
constexpr
index_t
C
=
1
28
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
128
;
...
...
@@ -253,7 +253,7 @@ int main(int argc, char* argv[])
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v
2
r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v
3
r1_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment