Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
20ddaeba
Commit
20ddaeba
authored
Apr 22, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
c5f1cdf7
43879b89
Changes
238
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4612 additions
and
788 deletions
+4612
-788
example/ck_tile/remod.py
example/ck_tile/remod.py
+21
-0
include/ck/tensor_description/multi_index_transform.hpp
include/ck/tensor_description/multi_index_transform.hpp
+85
-0
include/ck/tensor_description/multi_index_transform_helper.hpp
...de/ck/tensor_description/multi_index_transform_helper.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
...eration/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+354
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
...ion/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
+167
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
+732
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+1154
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+439
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+597
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+667
-0
include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp
.../tensor_operation/gpu/device/device_elementwise_scale.hpp
+5
-1
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+43
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp
.../gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp
+59
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
...or_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
+98
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
...ion/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
+81
-0
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+94
-72
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+8
-4
include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp
..._operation/gpu/device/impl/device_elementwise_2d_impl.hpp
+0
-338
include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
..._operation/gpu/device/impl/device_elementwise_3d_impl.hpp
+0
-371
No files found.
Too many changes to show.
To preserve performance only
238 of 238+
files are displayed.
Plain diff
Email patch
example/ck_tile/remod.py
0 → 100644
View file @
20ddaeba
import
pathlib
from
pathlib
import
Path
import
subprocess
import
os
import
copy
all_files
=
[]
for
p
in
sorted
(
Path
(
"./"
).
rglob
(
"*"
)):
if
p
.
suffix
in
[
'.hpp'
,
'.cpp'
]:
all_files
.
append
(
pathlib
.
PurePath
(
p
))
# formatting
for
x
in
all_files
:
subprocess
.
Popen
(
f
'dos2unix
{
str
(
x
)
}
'
,
shell
=
True
)
cmd
=
f
'clang-format-12 -style=file -i
{
str
(
x
)
}
'
#for xp in x.parents:
#print(get_file_base(x))
subprocess
.
Popen
(
cmd
,
shell
=
True
)
#print(all_files)
include/ck/tensor_description/multi_index_transform.hpp
View file @
20ddaeba
...
...
@@ -1951,4 +1951,89 @@ struct Modulo
printf
(
"}"
);
}
};
template
<
typename
LowLengths
>
struct
Xor
{
using
LowerIndex
=
MultiIndex
<
2
>
;
using
UpperIndex
=
MultiIndex
<
2
>
;
using
UpLengths
=
LowLengths
;
UpLengths
up_lengths_
;
__host__
__device__
constexpr
Xor
()
:
up_lengths_
{}
{}
__host__
__device__
constexpr
Xor
(
const
LowLengths
&
low_lengths
)
:
up_lengths_
{
low_lengths
}
{}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
2
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
2
;
}
__host__
__device__
constexpr
const
auto
&
GetUpperLengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
Size
()
==
2
&&
UpIdx
::
Size
()
==
2
,
"wrong! inconsistent # of dimension"
);
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}];
idx_low
(
Number
<
1
>
{})
=
idx_up
[
Number
<
1
>
{}]
^
(
idx_up
[
Number
<
0
>
{}]
%
up_lengths_
[
Number
<
1
>
{}]);
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
,
index_t
Hack
>
__host__
__device__
void
UpdateLowerIndex
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
,
Number
<
Hack
>
)
const
{
static_assert
(
LowIdxDiff
::
Size
()
==
2
&&
UpIdxDiff
::
Size
()
==
2
&&
LowIdx
::
Size
()
==
2
&&
UpIdx
::
Size
()
==
2
,
"wrong! inconsistent # of dimension"
);
const
auto
idx_low_old
=
idx_low
;
CalculateLowerIndex
(
idx_low
,
idx_up
);
idx_diff_low
=
idx_low
-
idx_low_old
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
template
<
typename
UpIdx
>
__host__
__device__
static
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsKnownAtCompileTime
()
{
return
is_known_at_compile_time
<
UpLengths
>::
value
;
}
__host__
__device__
void
Print
()
const
{
printf
(
"Xor{"
);
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
"}"
);
}
};
}
// namespace ck
include/ck/tensor_description/multi_index_transform_helper.hpp
View file @
20ddaeba
...
...
@@ -127,4 +127,10 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
{
return
Modulo
<
Modulus
,
UpLength
>
{
modulus
,
up_length
};
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_xor_transform
(
const
LowLengths
&
low_lengths
)
{
return
Xor
<
LowLengths
>
{
low_lengths
};
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
20ddaeba
...
...
@@ -960,13 +960,13 @@ struct BlockwiseGemmXdlops_pipeline_v4
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
K
Pack
*
MRepeat
*
KPack
>
{},
Number
<
MRepeat
*
KPack
>
{},
I1
));
Number
<
KPack
>
{},
Number
<
K
Repeat
*
MRepeat
*
KPack
>
{},
Number
<
MRepeat
*
KPack
>
{},
I1
));
// B[N0, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
K
Pack
*
M
Repeat
*
KPack
>
{},
Number
<
M
Repeat
*
KPack
>
{},
I1
));
Number
<
KPack
>
{},
Number
<
K
Repeat
*
N
Repeat
*
KPack
>
{},
Number
<
N
Repeat
*
KPack
>
{},
I1
));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
TransposeC
=
false
>
struct
BlockwiseGemmXdlops_pipeline_base
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
A_K0
=
ATileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BTileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
ATileDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BTileDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
ComputeDataType
,
MPerXDL
,
NPerXDL
,
KPack
,
ComputeDataType
,
TransposeC
>
{};
static
constexpr
index_t
AMmaKStride
=
KPack
;
static
constexpr
index_t
BMmaKStride
=
KPack
;
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPack
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
using
HotLoopInstList
=
ck
::
BlockwiseGemmXdlops_pipeline_hotloop_inst
<
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
A_K1
,
B_K1
,
A_K1
,
B_K1
,
MRepeat
,
NRepeat
,
MPerXDL
,
NPerXDL
,
xdlops_gemm
.
KPerXdlops
>
;
static_assert
(
KPerThread
%
KPack
==
0
,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack"
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPerThread
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
m0
,
n0
,
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
using
Tuple4
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmXdlops_pipeline_base
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
N
,
M0
,
M1
,
M2
));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
static
constexpr
AMmaTileDesc
a_block_desc_m0_m1_m2_k
;
static
constexpr
BMmaTileDesc
b_block_desc_n0_n1_n2_k
;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
KRepeat
*
MRepeat
*
KPack
>
{},
Number
<
MRepeat
*
KPack
>
{},
I1
));
// B[N0, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPack
>
{}),
make_tuple
(
Number
<
KPack
>
{},
Number
<
KRepeat
*
NRepeat
*
KPack
>
{},
Number
<
NRepeat
*
KPack
>
{},
I1
));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
namespace
ck
{
enum
struct
BlockGemmPipelineVersion
{
v1
,
// Naive
v2
,
// Mem
v3
,
// Comp
v4
,
// Comp, double lds buffer
v5
,
// Comp, double global prefetch register buffer
};
template
<
BlockGemmPipelineVersion
BlkGemmPipelineVer
,
BlockGemmPipelineScheduler
BlkGemmPipeSche
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
constexpr
auto
BlockGemmPipeline_Selector
()
{
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
return
BlockwiseGemmXdlops_pipeline_v1
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
return
BlockwiseGemmXdlops_pipeline_v2
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
return
BlockwiseGemmXdlops_pipeline_v3
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
return
BlockwiseGemmXdlops_pipeline_v4
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v5
)
{
return
BlockwiseGemmXdlops_pipeline_v5
<
BlkGemmPipeSche
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
{
std
::
cerr
<<
"BlockGemmPipeline configuration is not available"
<<
std
::
endl
;
}
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v1
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v1
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v1
<
BlockGemmPipelineScheduler
::
Interwave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
B_K1
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
static
constexpr
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
;
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
PrefetchStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
// but except the first, as we can shorten non-MAC cluster a bit and there's no
// observable negative impact. The desired effect is waves in a workgroup
// executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
// resource from other workgroups and reducing the chance of latency hiding by
// waiting for the rest of the workgroup at the eventual sync point.
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
// applying small delays to different wavefronts It is performed
// near the end of MAC cluster to minimize lgkmcnt penalty
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// block_sync_lds();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
}
protected:
// K->M loopover
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
MRepeat
*
KPerInnerLoop
>
{},
Number
<
MRepeat
*
KPerInnerLoop
>
{},
I1
));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
NRepeat
*
KPerInnerLoop
>
{},
Number
<
NRepeat
*
KPerInnerLoop
>
{},
I1
));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v2
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v2
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
(
4
*
warpSize
/
BlockSize
),
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch [2, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
// -------------------------------------------------------------------------------------------
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v2
<
BlockGemmPipelineScheduler
::
Interwave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
B_K1
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
static
constexpr
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
;
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
static
constexpr
index_t
KRepeat
=
KPerThread
/
KPerInnerLoop
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
math
::
integer_divide_ceil
(
32768
/
(
4
*
warpSize
/
BlockSize
),
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch [2, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
iprefetch
)
{
// -------------------------------------------------------------------------------------------
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
// block_sync_lds();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
Number
<
(
iprefetch
+
1
)
%
PrefetchStages
>
{});
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
iprefetch
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
iprefetch
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
LoopTailFunc
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
iprefetch
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
iprefetch
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
iprefetch
);
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
k0
.
value
!=
0
||
KRepeat
==
1
)
{
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
k_
+
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
if
constexpr
(
k0
.
value
==
KRepeat
-
1
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
LoopTailFunc
(
Number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
LoopTailFunc
(
Number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
LoopTailFunc
(
Number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
LoopTailFunc
(
Number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
LoopTailFunc
(
Number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
LoopTailFunc
(
Number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
LoopTailFunc
(
Number
<
PrefetchStages
>
{});
}
}
protected:
// K->M loopover
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
MRepeat
*
KPerInnerLoop
>
{},
Number
<
MRepeat
*
KPerInnerLoop
>
{},
I1
));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
Number
<
KRepeat
>
{},
Number
<
KPerInnerLoop
>
{}),
make_tuple
(
Number
<
KPerInnerLoop
>
{},
Number
<
KRepeat
*
NRepeat
*
KPerInnerLoop
>
{},
Number
<
NRepeat
*
KPerInnerLoop
>
{},
I1
));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v3
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v3
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
2
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
}
__device__
static
constexpr
auto
HotLoopScheduler
()
{
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_a_issue_cycle
-
1
)
/
ds_read_a_issue_cycle
;
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_b_issue_cycle
-
1
)
/
ds_read_b_issue_cycle
;
// stage 1
// Separate this part?
constexpr
auto
num_mfma_per_ds_read
=
sizeof
(
ComputeDataType
)
/
sizeof
(
ADataType
)
>
sizeof
(
ComputeDataType
)
/
sizeof
(
BDataType
)
?
sizeof
(
ComputeDataType
)
/
sizeof
(
ADataType
)
:
sizeof
(
ComputeDataType
)
/
sizeof
(
BDataType
);
constexpr
auto
num_mfma_stage1
=
num_mfma_inst
-
num_mfma_per_ds_read
*
(
num_ds_read_inst_a
/
ds_read_a_mfma_rate
+
num_ds_read_inst_b
/
ds_read_b_mfma_rate
);
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage1
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
// stage 2
static_for
<
0
,
num_ds_read_inst_a
/
ds_read_a_mfma_rate
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_ds_read
,
0
);
// MFMA
});
static_for
<
0
,
num_ds_read_inst_b
/
ds_read_b_mfma_rate
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_ds_read
,
0
);
// MFMA
});
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Compute optimimal pipeline with highest resource request
// GlobalPrefetchStages: 4
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v4
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v4
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
4
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
2
;
static
constexpr
index_t
HotloopUnroll
=
2
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
HotloopUnroll
==
1
)
{
return
TailNumber
::
Odd
;
}
else
{
return
TailNumber
::
Even
;
}
}
template
<
typename
ScheduleGroup
>
__device__
static
constexpr
void
HotLoopScheduler
(
ScheduleGroup
schedule_group
)
{
// TODO: Take data type into consideration as pipe ver 3
// A-B splited schedule
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_issue_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_dswrite_per_issue_a
=
(
HotLoopInstList
::
A_LDS_Write_Inst_Num
+
num_issue_a
-
1
)
/
num_issue_a
;
constexpr
auto
num_dsread_per_issue_a
=
num_ds_read_inst_a
/
num_issue_a
;
constexpr
auto
num_issue_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_dswrite_per_issue_b
=
(
HotLoopInstList
::
B_LDS_Write_Inst_Num
+
num_issue_b
-
1
)
/
num_issue_b
;
constexpr
auto
num_dsread_per_issue_b
=
num_ds_read_inst_b
/
num_issue_b
;
constexpr
auto
num_mfma_per_issue
=
HotLoopInstList
::
C_MFMA_Inst_Num
/
(
num_issue_a
+
num_issue_b
);
static_for
<
0
,
num_issue_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_a
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
schedule_group
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
});
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
schedule_group
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
schedule_group
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_a
,
schedule_group
);
// MFMA
});
static_for
<
0
,
num_issue_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_b
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
schedule_group
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
});
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
schedule_group
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
schedule_group
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_b
,
schedule_group
);
// MFMA
});
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
StaticallyIndexedArray
<
decltype
(
a_thread_buf
),
Number
<
2
>
{}
>
a_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
),
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
),
I0
);
// Local prefill 2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
),
I1
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
),
I1
);
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
I0
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
I0
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
I0
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
I0
));
});
});
});
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 4
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
auto
LoopFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
vmem_buf
,
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
vmem_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
vmem_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
]
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
]
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
HotLoopScheduler
(
schedule_group
);
};
LoopFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I0
);
LoopFunc
(
I0
,
I0
,
I1
,
I1
,
I1
,
I0
);
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
auto
ReadWriteCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
vmem_buf
,
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
HotLoopScheduler
(
schedule_group
);
};
auto
ReadCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
.
At
(
lds_read_buf
),
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
HotLoopScheduler
(
schedule_group
);
};
auto
CompFunc
=
[
&
](
auto
mfma_reg_buf
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_bufs
[
mfma_reg_buf
][
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_bufs
[
mfma_reg_buf
][
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
};
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I1
);
ReadCompFunc
(
I0
,
I0
,
I1
,
I1
);
CompFunc
(
I0
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I1
);
ReadWriteCompFunc
(
I0
,
I0
,
I1
,
I1
,
I1
,
I1
);
ReadCompFunc
(
I1
,
I1
,
I0
,
I1
);
CompFunc
(
I1
);
}
}
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace
ck
{
// Compute optimized pipeline
// GlobalPrefetchStages: 3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
template
<
BlockGemmPipelineScheduler
BlkGemmPipelineVer
,
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPacks
>
struct
BlockwiseGemmXdlops_pipeline_v5
{
};
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
AccDataType
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_v5
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
B_K1
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
GetCThreadBuffer
;
using
Base
::
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
2
;
static
constexpr
index_t
HotloopUnroll
=
2
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
HotloopUnroll
==
1
)
{
return
TailNumber
::
Odd
;
}
else
{
return
TailNumber
::
Even
;
}
}
__device__
static
constexpr
auto
HotLoopScheduler
()
{
// TODO: Take data type into consideration as pipe ver 3
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr
auto
num_ds_read_inst_a
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
HotLoopInstList
::
A_LDS_Read_Inst_Num
:
HotLoopInstList
::
A_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_read_inst_b
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
HotLoopInstList
::
B_LDS_Read_Inst_Num
:
HotLoopInstList
::
B_LDS_Read_Inst_Num
/
2
;
constexpr
auto
num_ds_write_inst_a
=
HotLoopInstList
::
A_LDS_Write_Inst_Num
;
constexpr
auto
num_ds_write_inst_b
=
HotLoopInstList
::
B_LDS_Write_Inst_Num
;
constexpr
auto
num_buffer_load_inst_a
=
HotLoopInstList
::
A_Buffer_Load_Inst_Num
;
constexpr
auto
num_buffer_load_inst_b
=
HotLoopInstList
::
B_Buffer_Load_Inst_Num
;
constexpr
auto
num_mfma_inst
=
HotLoopInstList
::
C_MFMA_Inst_Num
;
constexpr
auto
mfma_cycle
=
NPerXDL
==
16
?
16
:
32
;
constexpr
auto
ds_read_a_issue_cycle
=
HotLoopInstList
::
A_LDS_Read_Width
*
sizeof
(
ADataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_b_issue_cycle
=
HotLoopInstList
::
B_LDS_Read_Width
*
sizeof
(
BDataType
)
==
16
?
8
:
4
;
constexpr
auto
ds_read_a_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_a_issue_cycle
-
1
)
/
(
2
*
ds_read_a_issue_cycle
);
constexpr
auto
ds_read_b_mfma_rate
=
(
mfma_cycle
-
8
+
ds_read_b_issue_cycle
-
1
)
/
(
2
*
ds_read_b_issue_cycle
);
constexpr
auto
num_dsread_stage1_a
=
num_ds_read_inst_a
/
KRepeat
*
(
KRepeat
-
1
);
constexpr
auto
num_dsread_stage1_b
=
num_ds_read_inst_b
/
KRepeat
*
(
KRepeat
-
1
);
constexpr
auto
num_dsread_stage3_a
=
num_ds_read_inst_a
/
KRepeat
;
constexpr
auto
num_dsread_stage3_b
=
num_ds_read_inst_b
/
KRepeat
;
constexpr
auto
num_dsread_stage1_a_mfma
=
(
num_dsread_stage1_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_stage1_b_mfma
=
(
num_dsread_stage1_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
constexpr
auto
num_dsread_stage3_a_mfma
=
(
num_dsread_stage3_a
+
ds_read_a_mfma_rate
-
1
)
/
ds_read_a_mfma_rate
;
constexpr
auto
num_dsread_stage3_b_mfma
=
(
num_dsread_stage3_b
+
ds_read_b_mfma_rate
-
1
)
/
ds_read_b_mfma_rate
;
constexpr
auto
num_mfma_stage2
=
num_mfma_inst
-
num_ds_read_inst_a
/
ds_read_a_mfma_rate
-
num_ds_read_inst_b
/
ds_read_b_mfma_rate
;
constexpr
auto
num_mfma_per_issue
=
num_mfma_stage2
/
(
num_buffer_load_inst_a
+
num_buffer_load_inst_b
);
constexpr
auto
num_dswrite_per_issue_a
=
num_ds_write_inst_a
/
num_buffer_load_inst_a
;
constexpr
auto
num_dswrite_per_issue_b
=
num_ds_write_inst_b
/
num_buffer_load_inst_b
;
// stage 1
static_for
<
0
,
num_dsread_stage1_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
if
constexpr
((
num_dsread_stage1_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_dsread_stage1_a
-
(
num_dsread_stage1_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dsread_stage1_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
if
constexpr
((
num_dsread_stage1_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_dsread_stage1_b
-
(
num_dsread_stage1_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
// stage 2
static_for
<
0
,
num_buffer_load_inst_a
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_a
,
0
);
// MFMA
});
static_for
<
0
,
num_buffer_load_inst_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dswrite_per_issue_b
,
0
);
// MFMA
});
// stage 3
static_for
<
0
,
num_dsread_stage3_a_mfma
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
if
constexpr
((
num_dsread_stage3_a
-
(
i
+
1
)
*
ds_read_a_mfma_rate
)
>=
ds_read_a_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_a_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_dsread_stage3_a
-
(
num_dsread_stage3_a_mfma
-
1
)
*
ds_read_a_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
static_for
<
0
,
num_dsread_stage3_b_mfma
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
if
constexpr
((
num_dsread_stage3_b
-
(
i
+
1
)
*
ds_read_b_mfma_rate
)
>=
ds_read_b_mfma_rate
)
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
ds_read_b_mfma_rate
,
0
);
// DS read
}
else
{
__builtin_amdgcn_sched_group_barrier
(
0x100
,
num_dsread_stage3_b
-
(
num_dsread_stage3_b_mfma
-
1
)
*
ds_read_b_mfma_rate
,
0
);
// DS read
}
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
// IGLP COMPILER BUG:
// If comment out following scheduler barrier would cause sanity fail.
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
bool
HasMainLoop
,
TailNumber
TailNum
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
I0
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// Local prefetch 1
block_sync_lds
();
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
auto
LoopFunc
=
[
&
](
auto
vmem_buf
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
if
constexpr
(
k0
==
(
KRepeat
-
1
))
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
vmem_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
vmem_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
vmem_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
}
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
};
LoopFunc
(
I0
);
LoopFunc
(
I1
);
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
// tail
auto
ReadWriteCompFunc
=
[
&
](
auto
vmem_buf
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
if
constexpr
(
k0
==
(
KRepeat
-
1
))
{
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
,
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
,
vmem_buf
);
block_sync_lds
();
}
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
HotLoopScheduler
();
};
auto
ReadCompFunc
=
[
&
]()
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KRepeat
-
1
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
(
k0
+
1
)
%
KRepeat
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
I0
,
ik
))
>
{}];
});
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
I0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
HotLoopScheduler
();
};
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
ReadWriteCompFunc
(
I0
);
ReadWriteCompFunc
(
I1
);
ReadCompFunc
();
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
ReadWriteCompFunc
(
I0
);
ReadCompFunc
();
}
}
protected:
// A[MRepeat, I1, I1, KPack]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
// B[NRepeat, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPack
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
ComputeDataType
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
Base
::
CalculateBThreadOriginDataIndex
()};
using
Base
::
c_thread_desc_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -13,6 +13,10 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/**
* \note This structure is deprecated (left for backwards compatibility). Please use
* DeviceElementwise from device_elementwise.hpp.
*/
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemmV2
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
KSplit
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
typename
DsLayout
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
DsDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeightMultipleD
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsLayout
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_n_c_wis_lengths
,
// input
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_lengths
,
// output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_k_wos_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_k_c_xs_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
const
ck
::
index_t
split_k
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
struct
GemmMultiABDDesc
{
ck
::
index_t
M_
,
N_
,
K_
;
std
::
vector
<
ck
::
index_t
>
stride_As_
;
std
::
vector
<
ck
::
index_t
>
stride_Bs_
;
std
::
vector
<
ck
::
index_t
>
stride_Ds_
;
ck
::
index_t
stride_C_
;
};
/*
* \brief Grouped Gemm Multi ABD
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation C elementwise operation.
*/
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedGemmMultiABD
:
public
BaseOperator
{
static
constexpr
index_t
NumATensor
=
AsDataType
::
Size
();
static
constexpr
index_t
NumBTensor
=
BsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
AsLayout
::
Size
()
==
AsDataType
::
Size
(),
"wrong! inconsistent NumATensor"
);
static_assert
(
BsLayout
::
Size
()
==
BsDataType
::
Size
(),
"wrong! inconsistent NumBTensor"
);
static_assert
(
DsLayout
::
Size
()
==
DsDataType
::
Size
(),
"wrong! inconsistent NumDTensor"
);
/*
* \brief Make argument pointer for grouped gemm multi abd.
*
* \param p_as A pointers to the A.
* \param p_bs A pointers to the B.
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the E.
* \param gemm_desc Gemm descriptors for each group.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
std
::
array
<
const
void
*
,
NumATensor
>>&
p_as
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumBTensor
>>&
p_bs
,
std
::
vector
<
std
::
array
<
const
void
*
,
NumDTensor
>>&
p_ds
,
std
::
vector
<
void
*>&
p_e
,
std
::
vector
<
GemmMultiABDDesc
>&
gemm_desc
,
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
CDEElementwiseOperation
c_element_op
=
CDEElementwiseOperation
{})
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
void
SetElementwiseOps
(
BaseArgument
*
p_arg
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm_multi_abd.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
>
struct
GroupedGemmMultiABDKernelArgument
{
std
::
array
<
const
void
*
,
NumATensor
>
p_as_grid
;
std
::
array
<
const
void
*
,
NumBTensor
>
p_bs_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
std
::
array
<
index_t
,
NumATensor
>
StrideAs
;
std
::
array
<
index_t
,
NumBTensor
>
StrideBs
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
};
/*
* \brief Grouped Gemm Multi ABD Fixed NK
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation C elementwise operation.
*/
template
<
typename
AsLayout
,
typename
BsLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
AsDataType
,
typename
BsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmMultiABDFixedNK
:
DeviceGroupedGemmMultiABD
<
AsLayout
,
BsLayout
,
DsLayout
,
ELayout
,
AsDataType
,
BsDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_
1
d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_
2
d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -80,42 +80,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
MPerThread
=
Number
<
4
>
{};
static
constexpr
index_t
MPerThread
=
MPerBlock
/
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
static
constexpr
index_t
NPerThread
=
NPerBlock
/
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
AScalarPerVector
=
Number
<
4
>
{};
static
constexpr
auto
BScalarPerVector
=
Number
<
4
>
{};
static
constexpr
auto
CScalarPerVector
=
Number
<
4
>
{};
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_
1d
(
Desc_M
desc
_m
,
index_t
gridSize
,
index_t
blockSize
)
template
<
typename
Desc_M
_N
>
static
auto
PadDescriptor_M_
N
(
Desc_M
_N
desc
)
{
const
auto
M
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
MPerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
M
,
loop_step
)
-
M
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
M
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m_pad
;
const
auto
M
=
desc
.
GetLength
(
I0
);
const
auto
N
=
desc
.
GetLength
(
I1
);
const
auto
pad_M
=
math
::
integer_divide_ceil
(
M
,
MPerThread
)
*
MPerThread
-
M
;
const
auto
pad_N
=
math
::
integer_divide_ceil
(
N
,
NPerThread
)
*
NPerThread
-
N
;
const
auto
padded_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_right_pad_transform
(
M
,
pad_M
),
make_right_pad_transform
(
N
,
pad_N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
padded_desc
;
}
static
auto
MakeDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
,
index_t
gridSize
,
index_t
blockSize
)
static
auto
MakeDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
strides
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
2
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
2
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
const
auto
desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
2
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
return
PadDescriptor_M_N
(
desc
);
}
// GridwiseGemm
...
...
@@ -166,7 +165,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
using
CGridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
}
,
1
,
1
));
using
CGridDesc_M
_N
=
decltype
(
MakeDescriptor_M
_N
({
1
,
1
},
{
1
,
1
}));
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
,
public
GridwiseGemm
::
Problem
...
...
@@ -195,17 +194,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_grid_imag
{
p_c_grid_imag_
},
p_aux_grid
{
p_workspace
}
{
const
index_t
grid_size
=
std
::
get
<
1
>
(
GridwiseGemm
::
CalculateGridSize
(
M_
,
N_
));
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
=
DeviceOp
::
MakeDescriptor_M
({
M_
,
N_
},
{
StrideC_
,
I1
},
grid_size
,
BlockSize
);
c_grid_desc_m_n
=
DeviceOp
::
MakeDescriptor_M_N
({
M_
,
N_
},
{
StrideC_
,
I1
});
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
=
DeviceOp
::
MakeDescriptor_M
({
M_
,
N_
},
{
I1
,
StrideC_
},
grid_size
,
BlockSize
);
c_grid_desc_m_n
=
DeviceOp
::
MakeDescriptor_M_N
({
M_
,
N_
},
{
I1
,
StrideC_
});
}
p_aux_2_grid
=
p_workspace
+
GetCElementSpaceSize
(
M_
,
N_
,
StrideC_
);
...
...
@@ -220,7 +215,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_grid_imag
;
CDataType
*
p_aux_grid
;
CDataType
*
p_aux_2_grid
;
CGridDesc_M
c_grid_desc_m
;
CGridDesc_M
_N
c_grid_desc_m
_n
;
};
// Invoker
...
...
@@ -248,39 +243,62 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
Subtract
=
ck
::
tensor_operation
::
element_wise
::
Subtract
;
using
GridwiseBinAdd
=
GridwiseElementwise_1D
<
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
using
Block2TileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
using
GridwiseBinAdd
=
GridwiseElementwise
<
Tuple
<
CGridDesc_M_N
,
CGridDesc_M_N
>
,
Tuple
<
CGridDesc_M_N
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Block2TileMap
,
Add
,
BlockSize
,
MPerBlock
,
NPerBlock
,
MPerThread
,
NPerThread
,
Sequence
<
0
,
1
>
,
Sequence
<
AScalarPerVector
,
BScalarPerVector
>
,
Sequence
<
CScalarPerVector
>>
;
Sequence
<
CScalarPerVector
>
,
I1
,
I1
>
;
using
GridwiseBinSubtract
=
GridwiseElementwise
_1D
<
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
GridwiseElementwise
<
Tuple
<
CGridDesc_M
_N
,
CGridDesc_M
_N
>
,
Tuple
<
CGridDesc_M
_N
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Block2TileMap
,
Subtract
,
BlockSize
,
MPerBlock
,
NPerBlock
,
MPerThread
,
NPerThread
,
Sequence
<
0
,
1
>
,
Sequence
<
AScalarPerVector
,
BScalarPerVector
>
,
Sequence
<
CScalarPerVector
>>
;
Sequence
<
CScalarPerVector
>
,
I1
,
I1
>
;
const
index_t
M
=
arg
.
c_grid_desc_m_n
.
GetLength
(
I0
);
const
index_t
N
=
arg
.
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
block_2_tile_map
=
Block2TileMap
(
M
,
N
);
const
auto
add_kernel
=
kernel_elementwise
_1d
<
GridwiseBinAdd
,
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
const
auto
add_kernel
=
kernel_elementwise
<
GridwiseBinAdd
,
Tuple
<
CGridDesc_M
_N
,
CGridDesc_M
_N
>
,
Tuple
<
CGridDesc_M
_N
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Block2TileMap
,
Add
>
;
const
auto
subtract_kernel
=
kernel_elementwise
_1d
<
GridwiseBinSubtract
,
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
kernel_elementwise
<
GridwiseBinSubtract
,
Tuple
<
CGridDesc_M
_N
,
CGridDesc_M
_N
>
,
Tuple
<
CGridDesc_M
_N
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Block2TileMap
,
Subtract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
...
@@ -318,11 +336,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
arg
.
c_grid_desc_m
,
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
_n
,
arg
.
c_grid_desc_m
_n
),
make_tuple
(
arg
.
c_grid_desc_m
_n
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid
)),
make_tuple
(
arg
.
p_c_grid_real
),
block_2_tile_map
,
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -352,11 +371,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
arg
.
c_grid_desc_m
,
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
_n
,
arg
.
c_grid_desc_m
_n
),
make_tuple
(
arg
.
c_grid_desc_m
_n
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid
)),
make_tuple
(
arg
.
p_c_grid_imag
),
block_2_tile_map
,
Add
{});
}
else
...
...
@@ -394,11 +414,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
arg
.
c_grid_desc_m
,
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
_n
,
arg
.
c_grid_desc_m
_n
),
make_tuple
(
arg
.
c_grid_desc_m
_n
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid
)),
make_tuple
(
arg
.
p_c_grid_real
),
block_2_tile_map
,
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -428,11 +449,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
arg
.
c_grid_desc_m
,
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
),
make_tuple
(
arg
.
c_grid_desc_m
_n
,
arg
.
c_grid_desc_m
_n
),
make_tuple
(
arg
.
c_grid_desc_m
_n
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid
)),
make_tuple
(
arg
.
p_c_grid_imag
),
block_2_tile_map
,
Add
{});
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
20ddaeba
...
...
@@ -663,7 +663,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_kz_consecutive_
[
i
];
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
if
(
!
((
valid_a_vector_size
&&
valid_a_access_dim
)
||
ABlockTransferSrcScalarPerVector
==
1
))
{
valid_as_access
=
false
;
}
...
...
@@ -682,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_kz_consecutive_
[
i
];
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
if
(
!
((
valid_b_vector_size
&&
valid_b_access_dim
)
||
BBlockTransferSrcScalarPerVector
==
1
))
{
valid_bs_access
=
false
;
}
...
...
@@ -698,7 +700,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_nz_consecutive_
[
i
];
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
if
(
!
((
valid_d_vector_size
&&
valid_d_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
valid_ds_access
=
false
;
}
...
...
@@ -712,7 +715,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_nz_consecutive_
;
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
if
(
!
((
valid_e_vector_size
&&
valid_e_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp
deleted
100644 → 0
View file @
c5f1cdf7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
index_t
NumDim_m
,
index_t
NumDim_n
,
index_t
MPerThread
,
index_t
NPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
DeviceElementwise2dImpl
:
public
DeviceElementwise
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
NumDim_m
+
NumDim_n
>
{
static
constexpr
index_t
NumDim
=
NumDim_m
+
NumDim_n
;
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
auto
GenerateInDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
nullptr
);
},
Number
<
NumInput
>
{});
};
static
auto
GenerateOutDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
nullptr
);
},
Number
<
NumOutput
>
{});
};
using
InDataTypePointerTuple
=
decltype
(
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
GenerateOutDataTypePointerTuple
());
template
<
typename
Desc_MN
>
static
auto
PadDescriptor_MN_2d
(
Desc_MN
desc_mn
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
)
{
std
::
ignore
=
blockSize
;
std
::
ignore
=
gridSize
;
const
auto
m
=
desc_mn
.
GetLength
(
I0
);
const
auto
n
=
desc_mn
.
GetLength
(
I1
);
const
index_t
loop_step_m
=
num_threads_m
*
MPerThread
;
const
index_t
loop_step_n
=
num_threads_n
*
NPerThread
;
const
auto
pad_m
=
math
::
integer_least_multiple
(
m
,
loop_step_m
)
-
m
;
const
auto
pad_n
=
math
::
integer_least_multiple
(
n
,
loop_step_n
)
-
n
;
const
auto
desc_mn_pad
=
transform_tensor_descriptor
(
desc_mn
,
make_tuple
(
make_right_pad_transform
(
m
,
pad_m
),
make_right_pad_transform
(
n
,
pad_n
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
desc_mn_pad
;
}
static
auto
MakeDescriptor_MN
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
stride
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumDim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
NumDim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDim_m
,
1
>::
type
();
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDim_m
,
NumDim_m
+
NumDim_n
,
1
>::
type
();
const
auto
mLengths
=
get_container_subset
(
tupleOfShape
,
mDimIds
);
const
auto
nLengths
=
get_container_subset
(
tupleOfShape
,
nDimIds
);
// merge nd to 2d desc - [s0 * s1 * ...]
if
constexpr
(
NumDim
>
2
)
{
const
auto
desc_mn
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
)),
make_tuple
(
mDimIds
,
nDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadDescriptor_MN_2d
(
desc_mn
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
);
}
else
return
PadDescriptor_MN_2d
(
desc
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
);
}
template
<
index_t
TupleSize
>
static
auto
GenerateInOutGrid2dDescTuple
(
Number
<
TupleSize
>
)
{
return
generate_tuple
(
[
&
](
auto
)
{
if
constexpr
(
NumDim
>
2
)
{
return
MakeDescriptor_MN
({
1
,
1
},
{
1
,
1
},
1
,
1
,
1
,
1
);
}
else
{
return
MakeDescriptor_MN
({
1
},
{
1
},
1
,
1
,
1
,
1
);
};
},
Number
<
TupleSize
>
{});
};
using
OutGrid2dDescTuple
=
decltype
(
GenerateInOutGrid2dDescTuple
(
Number
<
NumOutput
>
{}));
using
InGrid2dDescTuple
=
decltype
(
GenerateInOutGrid2dDescTuple
(
Number
<
NumInput
>
{}));
using
GridwiseElementwise
=
GridwiseElementwise_2D
<
InGrid2dDescTuple
,
OutGrid2dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
MPerThread
,
NPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
:
lengths_
(
lengths
),
inStridesArray_
(
inStridesArray
),
outStridesArray_
(
outStridesArray
),
elementwise_op_
(
elementwise_op
),
blockSize_
(
256
)
{
static_assert
(
NumDim_m
>
0
,
""
);
static_assert
(
NumDim_n
>
0
,
""
);
in_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
in_dev_buffers
[
I
.
value
]);
},
Number
<
NumInput
>
{});
out_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
out_dev_buffers
[
I
.
value
]);
},
Number
<
NumOutput
>
{});
}
InDataTypePointerTuple
in_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
std
::
array
<
index_t
,
NumDim
>
lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
);
index_t
num_threads_m
=
(
gridSize
*
arg
.
blockSize_
)
/
16
;
index_t
num_threads_n
=
16
;
auto
in_grid_2d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MN
(
arg
.
lengths_
,
arg
.
inStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
);
},
Number
<
NumInput
>
{});
auto
out_grid_2d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MN
(
arg
.
lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
);
},
Number
<
NumOutput
>
{});
const
auto
kernel
=
kernel_elementwise_2d
<
GridwiseElementwise
,
InGrid2dDescTuple
,
OutGrid2dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
in_grid_2d_desc_tuple
,
out_grid_2d_desc_tuple
,
arg
.
in_dev_buffers_
,
arg
.
out_dev_buffers_
,
arg
.
elementwise_op_
,
num_threads_m
,
num_threads_n
);
return
elapsed_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
==
nullptr
)
return
false
;
if
(
pArg
->
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
,
index_t
vectorDim
)
{
if
(
strides
[
vectorDim
]
==
1
&&
(
lengths
[
vectorDim
]
%
scalarPerVector
==
0
||
lengths
[
vectorDim
]
%
scalarPerVector
==
lengths
[
vectorDim
]))
{
return
true
;
}
if
(
strides
[
vectorDim
]
!=
1
&&
scalarPerVector
==
strides
[
vectorDim
])
{
return
true
;
}
return
false
;
};
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
),
NumDim_m
-
1
))
valid
=
false
;
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
if
(
!
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
),
NumDim
-
1
))
valid
=
false
;
});
return
valid
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
deleted
100644 → 0
View file @
c5f1cdf7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
InDataTypeTuple
,
typename
OutDataTypeTuple
,
typename
ElementwiseOperation
,
index_t
NumDim_m
,
// choose how to set dims
index_t
NumDim_n
,
index_t
NumDim_k
,
index_t
MPerThread
,
index_t
NPerThread
,
index_t
KPerThread
,
typename
InScalarPerVectorSeq
,
typename
OutScalarPerVectorSeq
>
struct
DeviceElementwise3dImpl
:
public
DeviceElementwise
<
InDataTypeTuple
,
OutDataTypeTuple
,
ElementwiseOperation
,
NumDim_m
+
NumDim_n
+
NumDim_k
>
{
static
constexpr
index_t
NumDim
=
NumDim_m
+
NumDim_n
+
NumDim_k
;
static
constexpr
int
NumInput
=
InDataTypeTuple
::
Size
();
static
constexpr
int
NumOutput
=
OutDataTypeTuple
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static_assert
(
NumInput
==
InScalarPerVectorSeq
::
Size
()
&&
NumOutput
==
OutScalarPerVectorSeq
::
Size
(),
"Tuple size is inconsistent with the number of in/out!"
);
static
auto
GenerateInDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
nullptr
);
},
Number
<
NumInput
>
{});
}
static
auto
GenerateOutDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
nullptr
);
},
Number
<
NumOutput
>
{});
}
using
InDataTypePointerTuple
=
decltype
(
GenerateInDataTypePointerTuple
());
using
OutDataTypePointerTuple
=
decltype
(
GenerateOutDataTypePointerTuple
());
template
<
typename
Desc_MNK
>
static
auto
PadDescriptor_MNK
(
Desc_MNK
desc_mnk
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
,
index_t
num_threads_k
)
{
std
::
ignore
=
blockSize
;
std
::
ignore
=
gridSize
;
const
auto
m
=
desc_mnk
.
GetLength
(
I0
);
const
auto
n
=
desc_mnk
.
GetLength
(
I1
);
const
auto
k
=
desc_mnk
.
GetLength
(
I2
);
const
index_t
loop_step_m
=
num_threads_m
*
MPerThread
;
const
index_t
loop_step_n
=
num_threads_n
*
NPerThread
;
const
index_t
loop_step_k
=
num_threads_k
*
KPerThread
;
const
auto
pad_m
=
math
::
integer_least_multiple
(
m
,
loop_step_m
)
-
m
;
const
auto
pad_n
=
math
::
integer_least_multiple
(
n
,
loop_step_n
)
-
n
;
const
auto
pad_k
=
math
::
integer_least_multiple
(
k
,
loop_step_k
)
-
k
;
const
auto
desc_mnk_pad
=
transform_tensor_descriptor
(
desc_mnk
,
make_tuple
(
make_right_pad_transform
(
m
,
pad_m
),
make_right_pad_transform
(
n
,
pad_n
),
make_right_pad_transform
(
k
,
pad_k
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
desc_mnk_pad
;
}
static
auto
MakeDescriptor_MNK
(
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
stride
,
index_t
gridSize
,
index_t
blockSize
,
index_t
num_threads_m
,
index_t
num_threads_n
,
index_t
num_threads_k
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumDim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
NumDim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
constexpr
auto
mDimIds
=
typename
arithmetic_sequence_gen
<
0
,
NumDim_m
,
1
>::
type
();
constexpr
auto
nDimIds
=
typename
arithmetic_sequence_gen
<
NumDim_m
,
NumDim_m
+
NumDim_n
,
1
>::
type
();
constexpr
auto
kDimIds
=
typename
arithmetic_sequence_gen
<
NumDim_m
+
NumDim_n
,
NumDim
,
1
>::
type
();
const
auto
mLengths
=
get_container_subset
(
tupleOfShape
,
mDimIds
);
const
auto
nLengths
=
get_container_subset
(
tupleOfShape
,
nDimIds
);
const
auto
kLengths
=
get_container_subset
(
tupleOfShape
,
kDimIds
);
// merge nd to 3d desc - [s0 * s1 * ...]
if
constexpr
(
NumDim
>
3
)
{
const
auto
desc_mnk
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
mLengths
),
make_merge_transform
(
nLengths
),
make_merge_transform
(
kLengths
)),
make_tuple
(
mDimIds
,
nDimIds
,
kDimIds
),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
PadDescriptor_MNK
(
desc_mnk
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
}
else
return
PadDescriptor_MNK
(
desc
,
gridSize
,
blockSize
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
}
template
<
index_t
TupleSize
>
static
auto
GenerateInOutGrid3dDescTuple
(
Number
<
TupleSize
>
)
{
return
generate_tuple
(
[
&
](
auto
)
{
if
constexpr
(
NumDim
>
3
)
{
return
MakeDescriptor_MNK
({
1
,
1
,
1
},
{
1
,
1
,
1
},
1
,
1
,
1
,
1
,
1
);
}
else
{
return
MakeDescriptor_MNK
({
1
},
{
1
},
1
,
1
,
1
,
1
,
1
);
};
},
Number
<
TupleSize
>
{});
}
using
OutGrid3dDescTuple
=
decltype
(
GenerateInOutGrid3dDescTuple
(
Number
<
NumOutput
>
{}));
using
InGrid3dDescTuple
=
decltype
(
GenerateInOutGrid3dDescTuple
(
Number
<
NumInput
>
{}));
using
GridwiseElementwise
=
GridwiseElementwise_3D
<
InGrid3dDescTuple
,
OutGrid3dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
,
MPerThread
,
NPerThread
,
KPerThread
,
InScalarPerVectorSeq
,
OutScalarPerVectorSeq
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
:
lengths_
(
lengths
),
inStridesArray_
(
inStridesArray
),
outStridesArray_
(
outStridesArray
),
elementwise_op_
(
elementwise_op
),
blockSize_
(
256
)
{
static_assert
(
NumDim_m
>
0
,
""
);
static_assert
(
NumDim_n
>
0
,
""
);
static_assert
(
NumDim_k
>
0
,
""
);
in_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
in_dev_buffers
[
I
.
value
]);
},
Number
<
NumInput
>
{});
out_dev_buffers_
=
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
OutDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
DataType
*>
(
out_dev_buffers
[
I
.
value
]);
},
Number
<
NumOutput
>
{});
}
InDataTypePointerTuple
in_dev_buffers_
;
OutDataTypePointerTuple
out_dev_buffers_
;
std
::
array
<
index_t
,
NumDim
>
lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray_
;
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
)
*
arg
.
blockSize_
;
index_t
num_threads_m
=
gridSize
/
(
16
*
16
);
index_t
num_threads_n
=
16
;
index_t
num_threads_k
=
16
;
auto
in_grid_3d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MNK
(
arg
.
lengths_
,
arg
.
inStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
},
Number
<
NumInput
>
{});
auto
out_grid_3d_desc_tuple
=
generate_tuple
(
[
&
](
auto
I
)
{
return
MakeDescriptor_MNK
(
arg
.
lengths_
,
arg
.
outStridesArray_
[
I
.
value
],
gridSize
,
arg
.
blockSize_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
},
Number
<
NumOutput
>
{});
const
auto
kernel
=
kernel_elementwise_3d
<
GridwiseElementwise
,
InGrid3dDescTuple
,
OutGrid3dDescTuple
,
InDataTypePointerTuple
,
OutDataTypePointerTuple
,
ElementwiseOperation
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
0
,
in_grid_3d_desc_tuple
,
out_grid_3d_desc_tuple
,
arg
.
in_dev_buffers_
,
arg
.
out_dev_buffers_
,
arg
.
elementwise_op_
,
num_threads_m
,
num_threads_n
,
num_threads_k
);
return
elapsed_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
if
((
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
return
false
;
}
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
pArg
==
nullptr
)
return
false
;
if
(
pArg
->
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
auto
IsScalarPerVectorValid
=
[
&
](
const
std
::
array
<
index_t
,
NumDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumDim
>&
strides
,
index_t
scalarPerVector
,
index_t
vectorDim
)
{
if
(
strides
[
vectorDim
]
==
1
&&
(
lengths
[
vectorDim
]
%
scalarPerVector
==
0
||
lengths
[
vectorDim
]
%
scalarPerVector
==
lengths
[
vectorDim
]))
{
return
true
;
}
if
(
strides
[
vectorDim
]
>=
scalarPerVector
)
{
return
true
;
}
return
false
;
};
bool
valid
=
true
;
static_for
<
0
,
NumInput
,
1
>
{}([
&
](
auto
I
)
{
valid
=
valid
&&
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
inStridesArray_
[
I
.
value
],
InScalarPerVectorSeq
::
At
(
I
),
NumDim_m
-
1
);
});
static_for
<
0
,
NumOutput
,
1
>
{}([
&
](
auto
I
)
{
valid
=
valid
&&
IsScalarPerVectorValid
(
pArg
->
lengths_
,
pArg
->
outStridesArray_
[
I
.
value
],
OutScalarPerVectorSeq
::
At
(
I
),
NumDim
-
1
);
});
return
valid
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
NumDim
>
lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumInput
>
inStridesArray
,
const
std
::
array
<
std
::
array
<
index_t
,
NumDim
>
,
NumOutput
>
outStridesArray
,
const
std
::
array
<
const
void
*
,
NumInput
>
in_dev_buffers
,
const
std
::
array
<
void
*
,
NumOutput
>
out_dev_buffers
,
ElementwiseOperation
elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
lengths
,
inStridesArray
,
outStridesArray
,
in_dev_buffers
,
out_dev_buffers
,
elementwise_op
);
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
7
8
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment