Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
66edb259
Commit
66edb259
authored
Apr 04, 2019
by
Chao Liu
Browse files
Merge branch 'inline_asm_v2' of github.com:asroy/modular_convolution into inline_asm_v2
parents
19b41797
62c4d5df
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
48 deletions
+24
-48
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+24
-48
No files found.
src/include/blockwise_gemm.hip.hpp
View file @
66edb259
...
...
@@ -377,56 +377,34 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
constexpr
index_t
MRepeat
=
MPerThread
/
MPerThreadSubC
;
constexpr
index_t
NRepeat
=
NPerThread
/
NPerThreadSubC
;
// auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
// auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
Float4
*
reg_a
=
(
Float4
*
)(
p_a_thread
);
Float4
*
reg_b
=
(
Float4
*
)(
p_b_thread
);
Float4
*
reg_c
=
(
Float4
*
)(
p_c_thread
);
void
*
a_loc
=
(
void
*
)(
p_a_block
+
mMyThreadOffsetA
);
void
*
b_loc
=
(
void
*
)(
p_b_block
+
mMyThreadOffsetB
);
// loop over k
int
k_chunk
=
K
;
// for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk)
index_t
k_begin
=
0
;
{
#if 0
ds_read_b128(reg_a[0], a_loc, 0);
ds_read_b128(reg_a[1], a_loc, 256);
ds_read_b128(reg_b[0], b_loc, 0);
ds_read_b128(reg_b[1], b_loc, 128);
lgkmcnt(0);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else
int
k
=
k_begin
;
int
lds_a_block_off
=
sizeof
(
Float
)
*
M
;
int
lds_b_block_off
=
sizeof
(
Float
)
*
N
;
int
lds_a_block_off_1
=
MPerLevel1Cluster
*
sizeof
(
Float
);
int
lds_b_block_off_1
=
NPerLevel1Cluster
*
sizeof
(
Float
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
0
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
0
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
lgkmcnt
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
k_chunk
-
1
;
i
++
)
for
(
int
k_
i
=
1
;
k_i
<
K
;
k_
i
++
)
{
k
=
k
+
1
;
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k
*
lds_a_block_off
);
ds_read_b128
(
reg_a
[
0
],
a_loc
,
k_i
*
lds_a_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
*
lds_b_block_off
);
ds_read_b128
(
reg_b
[
0
],
b_loc
,
k
_i
*
lds_b_block_off
);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
*
lds_a_block_off
);
ds_read_b128
(
reg_b
[
1
],
b_loc
,
lds_b_block_off_1
+
k
_i
*
lds_b_block_off
);
ds_read_b128
(
reg_a
[
1
],
a_loc
,
lds_a_block_off_1
+
k
_i
*
lds_a_block_off
);
lgkmcnt
(
2
);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
lgkmcnt
(
1
);
...
...
@@ -435,8 +413,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
#endif
}
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
,
class
Accumulator
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment