Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
62ebdfde
"...resnet50_tensorflow.git" did not exist on "8dace71ec50b435217bff5b99f0823ea7ae03ec3"
Commit
62ebdfde
authored
Aug 17, 2021
by
Jing Zhang
Browse files
clean xdlops_gemm
parent
cb35d6fc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
266 additions
and
281 deletions
+266
-281
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+8
-35
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+257
-245
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
62ebdfde
...
...
@@ -32,7 +32,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
KPerBlock
=
K0
;
static
constexpr
index_t
KPack
=
K1
;
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
...
...
@@ -66,21 +65,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
laneId
=
wave_idx
[
I2
];
const
auto
blk
_idx
=
xdlops_gemm
.
GetBlkId
x
();
const
auto
xdlops_a
_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataInde
x
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
return
make_tuple
(
blk_id
,
0
,
waveId_m
,
blk_td
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
waveId_m
,
laneId
,
0
);
}
return
make_tuple
(
xdlops_a_idx
[
I0
],
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
0
);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
...
...
@@ -88,21 +76,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
laneId
=
wave_idx
[
I2
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
return
make_tuple
(
blk_id
,
0
,
waveId_n
,
blk_td
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
waveId_n
,
laneId
,
0
);
}
return
make_tuple
(
xdlops_b_idx
[
I0
],
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
0
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
...
...
@@ -145,10 +122,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
KPerBlock
%
xdlops_gemm
.
KPerXdlops
==
0
,
"KPerBlock is wrong!"
);
static_assert
(
K1
%
xdlops_gemm
.
mfma_type
.
k_base
==
0
,
"K1 is wrong!"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
...
...
@@ -234,10 +207,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type
<
FloatAB
,
K1
>
b_thread_vec
;
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPerBlock
,
xdlops_gemm
.
KPerXdlops
>
{}([
&
](
auto
k
0
)
{
// read A
a_thread_copy_
.
Run
(
a_k0_m0_m1_m2_k1_block_desc
,
make_tuple
(
k
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k
0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -245,14 +218,14 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// read B
b_thread_copy_
.
Run
(
b_k0_n0_n1_n2_k1_block_desc
,
make_tuple
(
k
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
k
0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_
base
>::
type
;
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
mfma_type
.
k_
per_blk
>::
type
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
62ebdfde
...
...
@@ -9,19 +9,16 @@ namespace ck {
enum
struct
mfma_instr
{
/// fp32
mfma_f32_32x32x1xf32
=
0
,
mfma_f32_16x16x1xf32
,
mfma_f32_4x4x1xf32
,
mfma_f32_32x32x2xf32
,
// k reduction
mfma_f32_16x16x4xf32
,
// k reduction
/// fp16
mfma_f32_32x32x4f16
,
mfma_f32_16x16x4f16
,
mfma_f32_4x4x4f16
,
mfma_f32_32x32x8f16
,
// k reduction
mfma_f32_16x16x16f16
,
// k reduction
/// bfp16
mfma_f32_32x32x2bf16
,
mfma_f32_16x16x2bf16
,
mfma_f32_4x4x2bf16
,
...
...
@@ -35,19 +32,17 @@ struct mfma_info;
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x1xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
1
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -64,19 +59,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x2xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
2
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -93,19 +86,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -122,19 +113,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x1xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
1
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -152,19 +141,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_4x4x1xf32
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
4
;
static
constexpr
index_t
m
=
4
;
static
constexpr
index_t
n
=
64
;
static
constexpr
index_t
k
=
1
;
static
constexpr
index_t
cycles
=
8
;
static
constexpr
index_t
k_base
=
1
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
4
;
static
constexpr
index_t
n_per_blk
=
64
;
static
constexpr
index_t
k_per_blk
=
1
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -181,19 +168,17 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
4
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -210,19 +195,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_32x32x8f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
4
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
32
;
static
constexpr
index_t
n
=
32
;
static
constexpr
index_t
k
=
8
;
static
constexpr
index_t
cycles
=
64
;
static
constexpr
index_t
k_base
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
4
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -239,19 +222,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x16f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
16
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
4
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -268,19 +249,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_16x16x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
num_regs_xdlops
=
num_regs_blk
*
num_output_blks
;
static
constexpr
index_t
m
=
16
;
static
constexpr
index_t
n
=
16
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
32
;
static
constexpr
index_t
k_base
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
4
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -297,19 +276,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
template
<
>
struct
mfma_info
<
mfma_instr
::
mfma_f32_4x4x4f16
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_blk
=
1
;
static
constexpr
index_t
num_regs_blk
=
group_size
*
num_groups_blk
;
static
constexpr
index_t
num_threads_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
num_regs_xdlops
=
4
;
static
constexpr
index_t
m
=
4
;
static
constexpr
index_t
n
=
64
;
static
constexpr
index_t
k
=
4
;
static
constexpr
index_t
cycles
=
8
;
static
constexpr
index_t
k_base
=
4
;
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_threads_per_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
4
;
static
constexpr
index_t
n_per_blk
=
64
;
static
constexpr
index_t
k_per_blk
=
4
;
static
constexpr
bool
is_k_reduction
=
false
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
...
...
@@ -327,19 +304,17 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
...
...
@@ -361,19 +336,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
...
...
@@ -394,19 +367,17 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
...
...
@@ -427,19 +398,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
...
...
@@ -460,19 +429,17 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 4;
static constexpr index_t n_per_blk = 64;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
...
...
@@ -505,14 +472,9 @@ struct xdlops_info
return
true
;
}
static
constexpr
bool
IsKReduction
()
{
return
(
mfma_type
.
num_output_blks
==
1
)
&&
(
mfma_type
.
num_input_blks
>
1
);
}
static
constexpr
index_t
GetKPerXdlops
()
{
return
IsKR
eduction
()
?
mfma_type
.
num_input_blks
:
1
;
return
mfma_type
.
is_k_r
eduction
?
mfma_type
.
num_input_blks
:
1
;
}
static
constexpr
index_t
GetNumCRegs
()
{
return
MPerXdlops
*
NPerXdlops
/
mfma_type
.
wave_size
;
}
...
...
@@ -521,6 +483,13 @@ struct xdlops_info
template
<
class
base_type
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPack
>
struct
XdlopsGemm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
template
<
class
base_type_
=
base_type
,
index_t
MPerWave_
=
MPerWave
,
index_t
NPerWave_
=
NPerWave
>
...
...
@@ -684,7 +653,30 @@ struct XdlopsGemm
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m
*
mfma_type
.
n
*
mfma_type
.
num_output_blks
);
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m_per_blk
*
mfma_type
.
n_per_blk
*
mfma_type
.
num_output_blks
);
}
__host__
__device__
static
void
mfma_info_check
()
{
static_assert
(
mfma_type
.
num_threads_per_blk
==
mfma_type
.
n_per_blk
,
"n_per_blk != num_threads_per_blk"
);
static_assert
(
mfma_type
.
num_regs_per_blk
*
mfma_type
.
num_input_blks
==
mfma_type
.
m_per_blk
,
"m_per_blk != num_input_blks * num_regs_per_blk"
);
static_assert
(
mfma_type
.
num_output_blks
==
mfma_type
.
num_input_blks
||
mfma_type
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
static_assert
(
mfma_type
.
num_regs_per_blk
*
mfma_type
.
wave_size
==
mfma_type
.
m_per_blk
*
mfma_type
.
n_per_blk
,
"num_regs_per_blk incorrect"
);
static_assert
(
mfma_type
.
is_k_reduction
||
(
mfma_type
.
num_input_blks
==
mfma_type
.
num_output_blks
),
"is_k_reduction wrong!"
);
}
__host__
__device__
constexpr
XdlopsGemm
()
...
...
@@ -696,30 +688,12 @@ struct XdlopsGemm
static_assert
(
MPerXdlops
==
4
||
MPerXdlops
==
8
||
MPerXdlops
==
16
||
MPerXdlops
==
32
||
MPerXdlops
==
64
,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"
);
static_assert
(
mfma_type
.
num_threads_blk
==
mfma_type
.
n
,
"n != num_threads_blk"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
num_input_blks
==
mfma_type
.
m
,
"m != num_input_blks * num_regs_blk"
);
static_assert
(
mfma_type
.
num_output_blks
==
mfma_type
.
num_input_blks
||
mfma_type
.
num_output_blks
==
1
,
"incorrect num_output_blks"
);
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
wave_size
==
mfma_type
.
m
*
mfma_type
.
n
,
"num_regs_blk incorrect"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k % kbase != 0!"
);
}
template
<
typename
CM0N0M1N1M2N2GridDesc
>
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CM0N0M1N1M2N2GridDesc
&
c_m0_n0_m1_n1_m2_n2_grid_desc
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I4
=
Number
<
4
>
{};
constexpr
auto
I5
=
Number
<
5
>
{};
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I2
);
...
...
@@ -727,9 +701,9 @@ struct XdlopsGemm
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I4
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_n2_grid_desc
.
GetLength
(
I5
);
static_assert
(
N2
==
mfma_type
.
num_threads_blk
,
""
);
static_assert
(
N2
==
mfma_type
.
num_threads_
per_
blk
,
""
);
static_assert
(
M2
==
(
mfma_type
.
num_groups_blk
*
mfma_type
.
num_output_blks
*
mfma_type
.
group_size
),
M2
==
(
mfma_type
.
num_groups_
per_
blk
*
mfma_type
.
num_output_blks
*
mfma_type
.
group_size
),
""
);
return
transform_dynamic_tensor_descriptor
(
...
...
@@ -738,10 +712,10 @@ struct XdlopsGemm
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
M1
),
make_pass_through_transform
(
N1
),
make_unmerge_transform
(
make_tuple
(
mfma_type
.
num_groups_blk
,
make_unmerge_transform
(
make_tuple
(
mfma_type
.
num_groups_
per_
blk
,
mfma_type
.
num_input_blks
,
mfma_type
.
group_size
)),
make_pass_through_transform
(
mfma_type
.
num_threads_blk
)),
make_pass_through_transform
(
mfma_type
.
num_threads_
per_
blk
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -768,40 +742,79 @@ struct XdlopsGemm
is_same
<
base_type
,
ushort
>::
value
,
"base base_type must be float, half, ushort!"
);
static_assert
(
KPack
%
mfma_type
.
k_
base
==
0
,
"KPack cannot be divided by k_
base
"
);
static_assert
(
KPack
%
mfma_type
.
k_
per_blk
==
0
,
"KPack cannot be divided by k_
per_blk
"
);
static_for
<
0
,
KPack
/
mfma_type
.
k_
base
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KPack
/
mfma_type
.
k_
per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
,
c_offset
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
});
}
static
constexpr
auto
GetBlkIdx
()
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
}
__device__
static
auto
GetBlkIdx
(
const
index_t
laneId
)
{
const
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
1
,
mfma_type
.
num_input_blks
,
mfma_type
.
num_threads_blk
))),
make_tuple
(
1
,
mfma_type
.
num_input_blks
,
mfma_type
.
num_threads_
per_
blk
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
blk_idx
=
threadidx_to_blk_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()
));
const
auto
blk_idx
=
threadidx_to_blk_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
laneId
));
const
auto
blk_id
=
blk_idx
[
Number
<
1
>
{}
];
const
auto
blk_td
=
blk_idx
[
Number
<
2
>
{}
];
const
auto
blk_id
=
blk_idx
[
I1
];
const
auto
blk_td
=
blk_idx
[
I2
];
return
make_tuple
(
blk_id
,
blk_td
);
}
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
mfma_type
.
is_k_reduction
)
{
return
make_tuple
(
blk_id
,
blk_td
);
}
else
{
return
make_tuple
(
0
,
laneId
);
}
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
if
constexpr
(
mfma_type
.
is_k_reduction
)
{
return
make_tuple
(
blk_id
,
blk_td
);
}
else
{
return
make_tuple
(
0
,
laneId
);
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
{
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_id
=
blk_idx
[
Number
<
0
>
{}
];
const
auto
blk_td
=
blk_idx
[
Number
<
1
>
{}
];
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
index_t
n_offset
=
blk_i
*
mfma_type
.
n
+
blk_td
;
index_t
m_offset
=
xdlops_i
*
mfma_type
.
m
+
blk_id
*
mfma_type
.
group_size
;
index_t
n_offset
=
blk_i
*
mfma_type
.
n
_per_blk
+
blk_td
;
index_t
m_offset
=
xdlops_i
*
mfma_type
.
m
_per_blk
+
blk_id
*
mfma_type
.
group_size
;
return
CIndex
{
m_offset
,
n_offset
};
}
...
...
@@ -810,26 +823,25 @@ struct XdlopsGemm
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
struct
CLayout
{
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_blk
;
}
__host__
__device__
static
constexpr
index_t
M1
()
{
return
mfma_type
.
num_groups_
per_
blk
;
}
__host__
__device__
static
constexpr
index_t
M0
()
{
return
mfma_type
.
group_size
;
}
__host__
__device__
static
constexpr
index_t
N1
()
{
return
mfma_type
.
num_input_blks
;
}
__host__
__device__
static
constexpr
index_t
N0
()
{
return
mfma_type
.
num_threads_blk
;
}
__host__
__device__
static
constexpr
index_t
N0
()
{
return
mfma_type
.
num_threads_
per_
blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_blk
;
}
__device__
static
constexpr
index_t
GetBlkSize
()
{
return
mfma_type
.
num_regs_
per_
blk
;
}
__device__
static
constexpr
index_t
GetNumBlks
()
{
return
mfma_type
.
num_output_blks
;
}
__device__
static
constexpr
index_t
GetNumXdlops
()
{
return
MPerXdlops
*
NPerXdlops
/
(
mfma_type
.
m
*
mfma_type
.
n
*
mfma_type
.
num_output_blks
);
(
mfma_type
.
m
_per_blk
*
mfma_type
.
n
_per_blk
*
mfma_type
.
num_output_blks
);
}
};
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
62ebdfde
...
...
@@ -3,7 +3,7 @@
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
//
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment