Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
1c4ef23c
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "216e3da60959ee5968d7424ac0943c86fbf55375"
Commit
1c4ef23c
authored
Aug 09, 2019
by
Chao Liu
Browse files
cleaning up
parent
4908fe3f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
21 deletions
+22
-21
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+12
-11
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+9
-9
driver/src/driver.cpp
driver/src/driver.cpp
+1
-1
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
1c4ef23c
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
...
@@ -23,8 +23,7 @@ template <index_t GridSize,
...
@@ -23,8 +23,7 @@ template <index_t GridSize,
index_t
BPerBlock
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
EPerBlock
,
index_t
N1
,
index_t
GemmNRepeat
,
index_t
N2
,
index_t
GemmMPerThreadSubC
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmMLevel0Cluster
,
...
@@ -47,17 +46,19 @@ template <index_t GridSize,
...
@@ -47,17 +46,19 @@ template <index_t GridSize,
class
WeiBlockCopySrcAccessOrder
,
class
WeiBlockCopySrcAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
class
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
,
index_t
WeiBlockCopyDstDataPerWrite_K
>
index_t
OutThreadCopyDataPerAccess_W
>
struct
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
struct
GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
__device__
void
__launch_bounds__
(
BlockSize
,
2
)
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
Float
*
const
__restrict__
p_out_global
)
const
{
{
// this is a mess
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
// TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert
(
N2
==
GemmNPerThreadSubC
,
"wrong!"
);
constexpr
index_t
N1
=
GemmNRepeat
;
constexpr
index_t
N2
=
GemmNPerThreadSubC
;
static_assert
((
N1
*
N2
*
BPerBlock
)
%
static_assert
((
N1
*
N2
*
BPerBlock
)
%
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
0
,
...
@@ -464,4 +465,4 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -464,4 +465,4 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
};
};
}
// namespace ck
}
// namespace ck
#endif
#endif
// CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
1c4ef23c
...
@@ -54,11 +54,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -54,11 +54,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
constexpr
index_t
N1
=
2
;
constexpr
index_t
N2
=
4
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
#if 1
#if 1
// each thread hold 64 data
// each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -67,6 +62,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -67,6 +62,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmNRepeat
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
...
@@ -168,6 +165,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -168,6 +165,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#endif
#endif
constexpr
index_t
N1
=
GemmNRepeat
;
constexpr
index_t
N2
=
GemmNPerThreadSubC
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
)
/
(
N1
*
N2
);
constexpr
index_t
GridSize
=
constexpr
index_t
GridSize
=
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
((
B
+
BPerBlock
-
1
)
/
BPerBlock
)
*
((
K
+
KPerBlock
-
1
)
/
KPerBlock
);
...
@@ -192,8 +194,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -192,8 +194,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
BPerBlock
,
BPerBlock
,
KPerBlock
,
KPerBlock
,
EPerBlock
,
EPerBlock
,
N1
,
GemmNRepeat
,
N2
,
GemmMPerThreadSubC
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmMLevel0Cluster
,
...
@@ -216,8 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -216,8 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
,
WeiBlockCopyDstDataPerWrite_K
>
{};
OutThreadCopyDataPerAccess_W
>
{};
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
...
...
driver/src/driver.cpp
View file @
1c4ef23c
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
...
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment