Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
66edb259
Commit
66edb259
authored
Apr 04, 2019
by
Chao Liu
Browse files
Merge branch 'inline_asm_v2' of github.com:asroy/modular_convolution into inline_asm_v2
parents
19b41797
62c4d5df
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
48 deletions
+24
-48
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+24
-48
No files found.
src/include/blockwise_gemm.hip.hpp
View file @
66edb259
...
...
@@ -361,10 +361,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// thread A-sub, B-sub for copy
constexpr
auto
a_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
Number
<
KPerThreadLoop
>
{},
Number
<
MPerThreadSubC
>
{},
Number
<
MPerThread
>
{});
constexpr
auto
b_thread_sub_mtx
=
make_ConstantMatrixDescriptor
(
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
float
p_thread
[
a_thread_mtx
.
GetElementSpace
()
+
b_thread_mtx
.
GetElementSpace
()];
...
...
@@ -377,66 +377,42 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
// auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
// loop over k
int
k_chunk
=
K
;
// for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
index_t
k_begin
=
0
;
{
#if 0
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_a[1], a_loc, 256);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[1], b_loc, 128);
lgkmcnt(0);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
0
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
#pragma unroll
for
(
int
k_i
=
1
;
k_i
<
K
;
k_i
++
)
{
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k_i
*
lds_a_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k_i
*
lds_b_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
#else
int
k
=
k_begin
;
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k_i
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k_i
*
lds_a_block_off
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
k_chunk
-
1
;
i
++
)
{
k
=
k
+
1
;
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
#endif
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment