Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
6a3f3f95
Commit
6a3f3f95
authored
Apr 03, 2019
by
Jing Zhang
Browse files
add
parent
b188c0d2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
14 deletions
+17
-14
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
+3
-3
driver/driver.hip.cpp
driver/driver.hip.cpp
+1
-1
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+12
-8
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
...idwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
+1
-2
No files found.
driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp
View file @
6a3f3f95
...
@@ -224,7 +224,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -224,7 +224,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
// 1x1, 14x14, Vega 20, hack CPerBlock = 1
// 1x1, 14x14, Vega 20, hack CPerBlock = 1
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
BPerBlock
=
64
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
CPerBlock
=
1
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
index_t
BPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
...
@@ -232,7 +232,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -232,7 +232,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
...
@@ -249,7 +249,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
...
@@ -249,7 +249,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
InBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
4
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
256
;
#endif
#endif
constexpr
index_t
GridSize
=
constexpr
index_t
GridSize
=
...
...
driver/driver.hip.cpp
View file @
6a3f3f95
...
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
...
@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 1x1 filter, 14x14 image, C = 2048
// 1x1 filter, 14x14 image, C = 2048
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
2048
;
constexpr
index_t
C
=
2048
;
...
...
src/include/blockwise_gemm.hip.hpp
View file @
6a3f3f95
...
@@ -404,10 +404,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -404,10 +404,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else
#else
int
k
=
k_begin
;
int
k
=
k_begin
;
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
512
);
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
256
);
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
ds_read_b128
(
reg_b
[
1
],
b_loc
,
128
+
k
*
256
);
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
256
+
k
*
512
);
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
lgkmcnt
(
2
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
lgkmcnt
(
1
);
...
@@ -416,12 +420,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -416,12 +420,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
for
(
int
i
=
0
;
i
<
k_chunk
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
k_chunk
-
1
;
i
++
)
{
{
k
=
k
+
1
;
k
=
k
+
1
;
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
512
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
256
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
128
+
k
*
256
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
256
+
k
*
512
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
lgkmcnt
(
2
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
lgkmcnt
(
1
);
...
...
src/include/gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp
View file @
6a3f3f95
...
@@ -297,8 +297,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
...
@@ -297,8 +297,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
for
(
index_t
c_block_data_begin
=
0
;
c_block_data_begin
<
C
;
c_block_data_begin
+=
CPerBlock
,
p_in_global_block_offset
+=
CPerBlock
*
in_cb_global_desc
.
GetStride
(
I0
),
p_in_global_block_offset
+=
CPerBlock
*
in_cb_global_desc
.
GetStride
(
I0
),
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
p_wei_global_block_offset
+=
CPerBlock
*
wei_cyxk_global_desc
.
GetStride
(
I0
),
__syncthreads
())
__syncthreads
())
{
{
// load data
// load data
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
blockwise_in_copy
.
Run
(
p_in_global_block_offset
,
p_in_block
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment