gaoqiong / composable_kernel · Commits · 18a81e35

Commit 18a81e35, authored Mar 22, 2019 by Chao Liu

adding assembly

Parent: 8c923db4

Showing 1 changed file with 102 additions and 11 deletions:

src/include/blockwise_gemm.hip.hpp (+102, -11)
@@ -435,11 +435,12 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
 #pragma unroll
 for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
 {
     // read first batch of A, B
     // copy A-sub to form A
-    #pragma unroll
+    // #pragma unroll
     for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
     {
+#if 0
         threadwise_matrix_copy(
             a_block_mtx,
             p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
@@ -447,12 +448,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
             a_thread_mtx,
             p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
             a_thread_sub_mtx.GetLengths());
+#else
+        for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i)
+        {
+            for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j)
+            {
+                p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
+                    p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
+                                                     m_repeat * MPerLevel1Cluster + j) +
+                              mMyThreadOffsetA];
+            }
+        }
+#endif
     }

     // copy B-sub to form B
-    #pragma unroll
+    // #pragma unroll
     for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
     {
+#if 0
         threadwise_matrix_copy(
             b_block_mtx,
             p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
@@ -460,13 +474,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -460,13 +474,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
b_thread_mtx,
b_thread_mtx,
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths());
b_thread_sub_mtx.GetLengths());
#else
for
(
unsigned
i
=
0
;
i
<
b_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
b_thread_mtx
.
NCol
();
++
j
)
{
p_b_thread
[
b_thread_mtx
.
Get1dIndex
(
i
,
n_repeat
*
NPerThreadSubC
+
j
)]
=
p_b_block
[
b_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
n_repeat
*
MPerLevel1Cluster
+
j
)
+
mMyThreadOffsetB
];
}
}
#endif
}
}
// loop over batch
// loop over batch
#pragma unroll
//
#pragma unroll
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
for
(
unsigned
ib
=
0
;
ib
+
1
<
BatchPerThread
;
++
ib
)
{
{
// do current batch of gemm
// do current batch of gemm
#if 0
threadwise_gemm(a_thread_mtx,
threadwise_gemm(a_thread_mtx,
True,
True,
p_a_thread,
p_a_thread,
...
@@ -477,13 +504,32 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -477,13 +504,32 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
False,
False,
p_c_thread + ib * ThreadMatrixStrideC,
p_c_thread + ib * ThreadMatrixStrideC,
f_accum);
f_accum);
#else
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
c_thread_mtx
.
NCol
();
++
j
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
j
);
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
j
)
+
ib
*
ThreadMatrixStrideC
;
f_accum
(
p_c_thread
[
cindex
],
p_a_thread
[
aindex
]
*
p_b_thread
[
bindex
]);
}
}
}
#endif
// read next batch of a, b
// read next batch of a, b
if
(
BlockMatrixStrideA
!=
0
)
if
(
BlockMatrixStrideA
!=
0
)
{
{
#pragma unroll
//
#pragma unroll
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
for
(
unsigned
m_repeat
=
0
;
m_repeat
<
MRepeat
;
++
m_repeat
)
{
{
#if 0
threadwise_matrix_copy(
threadwise_matrix_copy(
a_block_mtx,
a_block_mtx,
p_a_block +
p_a_block +
...
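The replacement for `threadwise_gemm` is a k-outer accumulation: A is held transposed (K x M), B is normal (K x N), and each k step performs a rank-1 update of the M x N accumulator C through the caller-supplied functor, i.e. C(i, j) += A(k, i) * B(k, j). A self-contained sketch of that loop order, with `f_accum` as a plain add-assign lambda (sizes and names are illustrative, not the library's):

#include <array>
#include <cstdio>

int main()
{
    constexpr unsigned K = 2, M = 3, N = 4; // illustrative sizes

    std::array<float, K * M> a{}; // A transposed: element (k, i) at a[k * M + i]
    std::array<float, K * N> b{}; // B normal:     element (k, j) at b[k * N + j]
    std::array<float, M * N> c{}; // C:            element (i, j) at c[i * N + j]

    for(unsigned x = 0; x < a.size(); ++x)
        a[x] = 1.0f + x;
    for(unsigned x = 0; x < b.size(); ++x)
        b[x] = 0.5f * x;

    // stand-in for the f_accum functor: plain add-assign
    auto f_accum = [](float& acc, float v) { acc += v; };

    for(unsigned k = 0; k < K; ++k)         // reduction dimension outermost
        for(unsigned i = 0; i < M; ++i)     // rows of C
            for(unsigned j = 0; j < N; ++j) // cols of C
                f_accum(c[i * N + j], a[k * M + i] * b[k * N + j]);

    printf("c(0,0) = %.1f\n", c[0]); // 1*0.0 + 4*2.0 = 8.0
    return 0;
}

Keeping k outermost means each (i, j) accumulator is touched K times but each A and B element is read once per use, which is the access pattern the register-resident thread GEMM wants.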
@@ -492,14 +538,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -492,14 +538,28 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
a_thread_mtx,
a_thread_mtx,
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths());
a_thread_sub_mtx.GetLengths());
#else
for
(
unsigned
i
=
0
;
i
<
a_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
a_thread_mtx
.
NCol
();
++
j
)
{
p_a_thread
[
a_thread_mtx
.
Get1dIndex
(
i
,
m_repeat
*
MPerThreadSubC
+
j
)]
=
p_a_block
[
a_block_mtx
.
Get1dIndex
(
k_begin
+
i
,
m_repeat
*
MPerLevel1Cluster
+
j
)
+
(
ib
+
1
)
*
BlockMatrixStrideA
+
mMyThreadOffsetA
];
}
}
#endif
}
}
}
}
if
(
BlockMatrixStrideB
!=
0
)
if
(
BlockMatrixStrideB
!=
0
)
{
{
#pragma unroll
//
#pragma unroll
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
for
(
unsigned
n_repeat
=
0
;
n_repeat
<
NRepeat
;
++
n_repeat
)
{
{
#if 0
threadwise_matrix_copy(
threadwise_matrix_copy(
b_block_mtx,
b_block_mtx,
p_b_block +
p_b_block +
...
@@ -508,11 +568,25 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                     b_thread_mtx,
                     p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                     b_thread_sub_mtx.GetLengths());
+#else
+                for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i)
+                {
+                    for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j)
+                    {
+                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
+                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
+                                                             n_repeat * NPerLevel1Cluster + j) +
+                                      (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
+                    }
+                }
+#endif
             }
         }
     }

     // do last batch of gemm
+#if 0
     threadwise_gemm(a_thread_mtx,
                     True,
                     p_a_thread,
@@ -523,6 +597,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -523,6 +597,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
False,
False,
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
p_c_thread + (BatchPerThread - 1) * ThreadMatrixStrideC,
f_accum);
f_accum);
#else
for
(
unsigned
k
=
0
;
k
<
a_thread_mtx
.
NRow
();
++
k
)
{
for
(
unsigned
i
=
0
;
i
<
c_thread_mtx
.
NRow
();
++
i
)
{
for
(
unsigned
j
=
0
;
j
<
c_thread_mtx
.
NCol
();
++
j
)
{
const
unsigned
aindex
=
a_thread_mtx
.
Get1dIndex
(
k
,
i
);
// A is transposed
const
unsigned
bindex
=
b_thread_mtx
.
Get1dIndex
(
k
,
j
);
const
unsigned
cindex
=
c_thread_mtx
.
Get1dIndex
(
i
,
j
)
+
(
BatchPerThread
-
1
)
*
ThreadMatrixStrideC
;
f_accum
(
p_c_thread
[
cindex
],
p_a_thread
[
aindex
]
*
p_b_thread
[
bindex
]);
}
}
}
#endif
}
}
}
}
...
...
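Control-flow-wise, the batch loop is a compute-then-prefetch pipeline: the first batch is read before the loop, each iteration runs the GEMM for batch ib and then loads batch ib + 1, and the final batch is computed after the loop so the prefetch never reads past the end (hence the `ib + 1 < BatchPerThread` condition). A condensed, runnable sketch of that structure; the flat buffers and the load/gemm helpers are illustrative stand-ins, not the library's code:

#include <cstdio>
#include <vector>

int main()
{
    const unsigned BatchPerThread = 4, BatchStride = 16; // illustrative sizes

    std::vector<float> block(BatchPerThread * BatchStride, 1.0f); // all batches
    std::vector<float> thread_buf(BatchStride);                   // one batch
    float acc = 0.0f;

    // read batch ib from the block buffer into the thread buffer
    auto load = [&](unsigned ib) {
        for(unsigned x = 0; x < BatchStride; ++x)
            thread_buf[x] = block[ib * BatchStride + x];
    };
    // stand-in for one batch's k/i/j GEMM loops
    auto gemm = [&]() {
        for(float v : thread_buf)
            acc += v;
    };

    load(0); // read first batch

    for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
    {
        gemm();       // do current batch of gemm
        load(ib + 1); // read next batch
    }

    gemm(); // do last batch of gemm

    printf("acc = %.0f\n", acc); // 64: every element visited exactly once
    return 0;
}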