Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d8970ea0
Commit
d8970ea0
authored
Jul 03, 2019
by
Jing Zhang
Browse files
Add double register buffering to the blockwise GEMM
parent
f05b210a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
0 deletions
+35
-0
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+35
-0
No files found.
src/include/blockwise_gemm.hip.hpp
View file @
d8970ea0
...
...
@@ -145,6 +145,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
Number
<
KPerThreadLoop
>
{},
Number
<
NPerThreadSubC
>
{},
Number
<
NPerThread
>
{});
FloatA
p_a_thread
[
a_thread_mtx
.
GetElementSpace
()];
//FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 2];
FloatB
p_b_thread
[
b_thread_mtx
.
GetElementSpace
()];
constexpr
index_t
MPerLevel1Cluster
=
MPerThreadSubC
*
MLevel0Cluster
*
MLevel1Cluster
;
...
...
@@ -173,6 +174,39 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
NPerLevel1Cluster
]);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
MPerLevel1Cluster
]);
#if 0
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
int b_reg_0 = (k % 2) * 2;
int b_reg_1 = ((k - 1) % 2) * 2;
reg_b[b_reg_0 + 0] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
reg_b[b_reg_0 + 1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(
reg_a[1], reg_b[b_reg_1 + 0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(
reg_a[1], reg_b[b_reg_1 + 1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
}
outerProduct4x4(reg_a[0], reg_b[2], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[3], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
outerProduct4x4(reg_a[1], reg_b[2], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[3], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else
reg_a
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
]);
reg_b
[
0
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
]);
reg_b
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_b_block
[
mMyThreadOffsetB
+
NPerLevel1Cluster
]);
reg_a
[
1
]
=
*
reinterpret_cast
<
const
Float4
*>
(
&
p_a_block
[
mMyThreadOffsetA
+
MPerLevel1Cluster
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
0
],
reg_c
[
0
],
reg_c
[
2
],
reg_c
[
4
],
reg_c
[
6
]);
outerProduct4x4
(
reg_a
[
0
],
reg_b
[
1
],
reg_c
[
1
],
reg_c
[
3
],
reg_c
[
5
],
reg_c
[
7
]);
#pragma unroll
...
...
@@ -191,6 +225,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
0
],
reg_c
[
8
],
reg_c
[
10
],
reg_c
[
12
],
reg_c
[
14
]);
outerProduct4x4
(
reg_a
[
1
],
reg_b
[
1
],
reg_c
[
9
],
reg_c
[
11
],
reg_c
[
13
],
reg_c
[
15
]);
#endif
}
#endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment