Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
62ebdfde
"docs/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "cf77e9b525e3a0f5b844387b73284df1a72c1ee6"
Commit
62ebdfde
authored
Aug 17, 2021
by
Jing Zhang
Browse files
clean xdlops_gemm
parent
cb35d6fc
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
266 additions
and
281 deletions
+266
-281
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+8
-35
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+257
-245
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
62ebdfde
...
@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
KPerBlock
=
K0
;
static
constexpr
index_t
KPerBlock
=
K0
;
static
constexpr
index_t
KPack
=
K1
;
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
...
@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
laneId
=
wave_idx
[
I2
];
const
auto
blk
_idx
=
xdlops_gemm
.
GetBlkId
x
();
const
auto
xdlops_a
_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataInde
x
();
const
auto
blk_id
=
blk_idx
[
I0
];
return
make_tuple
(
xdlops_a_idx
[
I0
],
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
0
);
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
return
make_tuple
(
blk_id
,
0
,
waveId_m
,
blk_td
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
waveId_m
,
laneId
,
0
);
}
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
__device__
static
auto
CalculateBThreadOriginDataIndex
()
...
@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
laneId
=
wave_idx
[
I2
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
return
make_tuple
(
xdlops_b_idx
[
I0
],
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
0
);
{
return
make_tuple
(
blk_id
,
0
,
waveId_n
,
blk_td
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
waveId_n
,
laneId
,
0
);
}
}
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
...
@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
KPerBlock
%
xdlops_gemm
.
KPerXdlops
==
0
,
"KPerBlock is wrong!"
);
static_assert
(
K1
%
xdlops_gemm
.
mfma_type
.
k_base
==
0
,
"K1 is wrong!"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
"wrong!"
);
...
@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
0
)
{
// read A
// read A
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
make_tuple
(
k
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k
0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_block_buf
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// read B
// read B
b_thread_copy_
.
Run
(
b_k0_n0_n1_n2_k1_block_desc
,
b_thread_copy_
.
Run
(
b_k0_n0_n1_n2_k1_block_desc
,
make_tuple
(
k
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k
0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_block_buf
,
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_
base
>::
type
;
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_
per_blk
>::
type
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
62ebdfde
This diff is collapsed.
Click to expand it.
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
62ebdfde
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
//
#include <half.hpp>
#include "config.hpp"
#include "config.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment