Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
5696c81f
Commit
5696c81f
authored
Apr 10, 2019
by
Chao Liu
Browse files
simplify blockwise batched GEMM
parent
71434918
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
56 deletions
+16
-56
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+16
-56
No files found.
src/include/blockwise_batched_gemm.hip.hpp
View file @
5696c81f
...
...
@@ -211,52 +211,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#pragma unroll
for
(
index_t
k_begin
=
0
;
k_begin
<
KPerBlock
;
k_begin
+=
KPerThreadLoop
)
{
// read first batch of A, B
// copy A-sub to form A
#pragma unroll
for
(
index_t
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
threadwise_matrix_copy
(
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
(),
Number
<
DataPerReadA
>
{});
}
// copy B-sub to form B
#pragma unroll
for
(
index_t
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
threadwise_matrix_copy
(
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
(),
Number
<
DataPerReadB
>
{});
}
// loop over batch
#pragma unroll
for
(
index_t
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
for
(
index_t
ib
=
0
;
ib
<
BatchPerThread
;
++
ib
)
{
// do current batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
b_thread_mtx
,
False
,
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
ib
*
ThreadMatrixStrideC
);
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
if
(
BlockMatrixStrideA
!=
0
or
ib
==
0
)
{
#pragma unroll
for
(
index_t
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
...
...
@@ -265,7 +225,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
a_block_mtx
,
p_a_block
+
a_block_mtx
.
Get1dIndex
(
k_begin
,
m_repeat
*
MPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
,
ib
*
BlockMatrixStrideA
+
mMyThreadOffsetA
,
a_thread_mtx
,
p_a_thread
+
a_thread_mtx
.
Get1dIndex
(
0
,
m_repeat
*
MPerThreadSubC
),
a_thread_sub_mtx
.
GetLengths
(),
...
...
@@ -273,7 +233,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
if
(
BlockMatrixStrideB
!=
0
)
if
(
BlockMatrixStrideB
!=
0
or
ib
==
0
)
{
#pragma unroll
for
(
index_t
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
...
...
@@ -282,16 +242,14 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
b_block_mtx
,
p_b_block
+
b_block_mtx
.
Get1dIndex
(
k_begin
,
n_repeat
*
NPerLevel1Cluster
)
+
(
ib
+
1
)
*
BlockMatrixStrideB
+
mMyThreadOffsetB
,
ib
*
BlockMatrixStrideB
+
mMyThreadOffsetB
,
b_thread_mtx
,
p_b_thread
+
b_thread_mtx
.
Get1dIndex
(
0
,
n_repeat
*
NPerThreadSubC
),
b_thread_sub_mtx
.
GetLengths
(),
Number
<
DataPerReadB
>
{});
}
}
}
// do last batch of gemm
threadwise_gemm
(
a_thread_mtx
,
True
,
p_a_thread
,
...
...
@@ -300,7 +258,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
p_b_thread
,
c_thread_mtx
,
False
,
p_c_thread
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
);
p_c_thread
+
ib
*
ThreadMatrixStrideC
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment