yangql / composable_kernel-1
Commit 22114959, authored Mar 22, 2019 by Chao Liu

adding inline asm for 16x4 gemm

parent 68ae0731
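This commit replaces the generic accumulator in the blockwise GEMM inner loop with hand-written GCN inline assembly specialized for a 16x4 per-thread C tile: for each k step and each of the 16 rows of C, one broadcast A value is multiplied against the 4-wide B row by four v_mac_f32 instructions (v_mac_f32 dst, src0, src1 computes dst += src0 * src1 on a VGPR destination).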
Showing 2 changed files with 232 additions and 6 deletions:
src/include/blockwise_gemm.hip.hpp (+225 −1)
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp (+7 −5)
src/include/blockwise_gemm.hip.hpp
...
...
@@ -526,7 +526,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
// this version puts copy and compute in the same place, experimenting with compiler behaviour
    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run_v2(const FloatA* __restrict__ p_a_block,
                           const FloatB* __restrict__ p_b_block,
...
...
@@ -687,6 +686,231 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void Run_v3(const FloatA* __restrict__ p_a_block,
                           const FloatB* __restrict__ p_b_block,
                           FloatC* __restrict__ p_c_thread,
                           Accumulator f_accum) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr unsigned MPerThread = c_thread_mtx.NRow();
        constexpr unsigned NPerThread = c_thread_mtx.NCol();
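        // e.g. for the inline-asm path below (see its static_assert):
        // MPerThread == 16, NPerThread == 4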
        // thread A, B for GEMM
        // A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
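        // note: these per-thread arrays are expected to live in registers
        // (VGPRs); the "v" constraints in the inline asm below require that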
        constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
        constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
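        // each thread accumulates an MPerThread x NPerThread tile of C,
        // gathered as MRepeat x NRepeat sub-tiles spaced MPerLevel1Cluster /
        // NPerLevel1Cluster apart in the block-level A / B matrices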
        // loop over k
        //#pragma unroll
        for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
            // read first batch of A, B
            // copy A-sub to form A
            //#pragma unroll
            for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                {
                    for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
                    {
                        p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
                            p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
                                                             m_repeat * MPerLevel1Cluster + j) +
                                      mMyThreadOffsetA];
                    }
                }
            }
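            // p_a_thread now holds a contiguous KPerThreadLoop x MPerThread
            // slice of A for this thread, re-packed from MRepeat strided
            // sub-tiles of the block-level matrix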
            // copy B-sub to form B
            //#pragma unroll
            for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                {
                    for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
                    {
                        p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
                            p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
                                                             n_repeat * NPerLevel1Cluster + j) +
                                      mMyThreadOffsetB];
                    }
                }
            }
            // loop over batch
            //#pragma unroll
            for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
            {
                // do current batch of gemm
                for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
                {
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif 1
                    static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
                                  "asm is only for 16x4");

                    const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
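                    // dataflow of the asm path: for this k, the 4-wide B row
                    // (bindex .. bindex+3) is reused by all 16 rows of C, and
                    // each row i contributes one broadcast A value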
                    for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
                    {
                        const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
                        const unsigned cindex =
                            c_thread_mtx.Get1dIndex(i, 0) + ib * ThreadMatrixStrideC;
                        asm volatile("\n \
                            v_mac_f32 %0, %4, %5 \n \
                            v_mac_f32 %1, %4, %6 \n \
                            v_mac_f32 %2, %4, %7 \n \
                            v_mac_f32 %3, %4, %8 \n \
                            "
                                     : "=v"(p_c_thread[cindex + 0]),
                                       "=v"(p_c_thread[cindex + 1]),
                                       "=v"(p_c_thread[cindex + 2]),
                                       "=v"(p_c_thread[cindex + 3])
                                     : "v"(p_a_thread[aindex]),
                                       "v"(p_b_thread[bindex + 0]),
                                       "v"(p_b_thread[bindex + 1]),
                                       "v"(p_b_thread[bindex + 2]),
                                       "v"(p_b_thread[bindex + 3]),
                                       "0"(p_c_thread[cindex + 0]),
                                       "1"(p_c_thread[cindex + 1]),
                                       "2"(p_c_thread[cindex + 2]),
                                       "3"(p_c_thread[cindex + 3]));
                    }
#endif
                }
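                // note on the asm constraints: each "=v" output is a VGPR
                // holding one C element, and the matching "0".."3" inputs tie
                // it to the same register, expressing the read-modify-write
                // that v_mac_f32 (dst += src0 * src1) performs on its
                // destination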
                // read next batch of A, B
                if(BlockMatrixStrideA != 0)
                {
                    //#pragma unroll
                    for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        for(unsigned i = 0; i < a_thread_sub_mtx.NRow(); ++i)
                        {
                            for(unsigned j = 0; j < a_thread_sub_mtx.NCol(); ++j)
                            {
                                p_a_thread[a_thread_mtx.Get1dIndex(
                                    i, m_repeat * MPerThreadSubC + j)] =
                                    p_a_block[a_block_mtx.Get1dIndex(
                                                  k_begin + i, m_repeat * MPerLevel1Cluster + j) +
                                              (ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
                            }
                        }
                    }
                }
                if(BlockMatrixStrideB != 0)
                {
                    //#pragma unroll
                    for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        for(unsigned i = 0; i < b_thread_sub_mtx.NRow(); ++i)
                        {
                            for(unsigned j = 0; j < b_thread_sub_mtx.NCol(); ++j)
                            {
                                p_b_thread[b_thread_mtx.Get1dIndex(
                                    i, n_repeat * NPerThreadSubC + j)] =
                                    p_b_block[b_block_mtx.Get1dIndex(
                                                  k_begin + i, n_repeat * NPerLevel1Cluster + j) +
                                              (ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
                            }
                        }
                    }
                }
            }
            // do last batch of gemm
            for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
            {
#if 0
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif 1
                static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
                              "asm is only for 16x4");

                const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
                for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
                {
                    const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
                    const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0) +
                                            (BatchPerThread - 1) * ThreadMatrixStrideC;
                    asm volatile("\n \
                        v_mac_f32 %0, %4, %5 \n \
                        v_mac_f32 %1, %4, %6 \n \
                        v_mac_f32 %2, %4, %7 \n \
                        v_mac_f32 %3, %4, %8 \n \
                        "
                                 : "=v"(p_c_thread[cindex + 0]),
                                   "=v"(p_c_thread[cindex + 1]),
                                   "=v"(p_c_thread[cindex + 2]),
                                   "=v"(p_c_thread[cindex + 3])
                                 : "v"(p_a_thread[aindex]),
                                   "v"(p_b_thread[bindex + 0]),
                                   "v"(p_b_thread[bindex + 1]),
                                   "v"(p_b_thread[bindex + 2]),
                                   "v"(p_b_thread[bindex + 3]),
                                   "0"(p_c_thread[cindex + 0]),
                                   "1"(p_c_thread[cindex + 1]),
                                   "2"(p_c_thread[cindex + 2]),
                                   "3"(p_c_thread[cindex + 3]));
                }
#endif
            }
        }
    }
    template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
...
...
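For reference, here is a minimal, self-contained sketch (not part of this commit) of the same v_mac_f32 inline-asm pattern in isolation. It assumes the ROCm/HIP toolchain targeting a GCN-class AMD GPU; the helper name mac_1x4 is invented for illustration:

    // a minimal sketch, assuming ROCm/HIP on a GCN-class GPU;
    // mac_1x4 is a hypothetical helper, not part of this commit
    #include <hip/hip_runtime.h>

    // c[0..3] += a * b[0..3], one v_mac_f32 per C element; the tied
    // "0".."3" constraints make each destination VGPR both input and output
    __device__ void mac_1x4(float a, const float* b, float* c)
    {
        asm volatile("\n \
            v_mac_f32 %0, %4, %5 \n \
            v_mac_f32 %1, %4, %6 \n \
            v_mac_f32 %2, %4, %7 \n \
            v_mac_f32 %3, %4, %8 \n \
            "
                     : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
                     : "v"(a), "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]),
                       "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]));
    }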
src/include/gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn.hip.hpp
...
...
@@ -209,15 +209,17 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
        {
            for(unsigned x = 0; x < X; ++x)
            {
#if 0
                blockwise_batch_gemm.Run
#elif 0
                blockwise_batch_gemm.Run_v2
#elif 1
                blockwise_batch_gemm.Run_v3
#endif
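                // Run_v3, the inline-asm variant added by this commit, is
                // the branch selected above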
                    (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                     p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
                     p_out_thread,
                     [](auto& acc, const auto&& v) { acc += v; });
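                // the lambda is the f_accum accumulator (acc += v); note that
                // the Run_v3 asm path does its own accumulation and never
                // calls it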
}
}
}
...
...