Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3f9dbcac
Commit
3f9dbcac
authored
Dec 30, 2024
by
coderfeli
Browse files
use new pipeline for b preshuffle, run ok; revert olds to fix ckprofiler
parent
54f44e62
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
3263 additions
and
326 deletions
+3263
-326
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
...gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
+3
-110
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
...gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
+527
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+45
-110
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
...pu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
+46
-46
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+440
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+523
-60
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+1679
-0
No files found.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
View file @
3f9dbcac
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3
_b_preshuffle
.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
...
@@ -27,8 +27,6 @@ using S = ck::Sequence<Is...>;
...
@@ -27,8 +27,6 @@ using S = ck::Sequence<Is...>;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
// using I8 = int8_t;
// using I32 = int;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
FP8
=
ck
::
f8_t
;
using
FP8
=
ck
::
f8_t
;
using
F32
=
float
;
using
F32
=
float
;
...
@@ -79,109 +77,6 @@ struct MultiplyMultiply
...
@@ -79,109 +77,6 @@ struct MultiplyMultiply
};
};
// struct MultiplyMultiply
// {
// template <typename E, typename C, typename D0, typename D1>
// __host__ __device__ constexpr void
// operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
// ck::half_t& e, const float& c, const float& d0, const float& d1) const
// {
// const float x0_f = c * d0 * d1;
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
// ck::half_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
// ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::bhalf_t>(x0_f);
// }
// };
// void reinit2(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// int kinit = 0;
// for (int k = 0; k < K; k+=1) {
// // dst[n * K + k] = n;
// if(k>0 && k%128==0){
// kinit += 1;
// }
// dst[n * K + k] = k % 128 + kinit;//rand() % 5 - 2;
// }
// }
// }
// void reinit(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; k+=1) {
// dst[n * K + k] = ck::type_convert<FP8>(float(1));
// }
// }
// }
void
dump
(
FP8
*
dst
,
int
N
,
int
K
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
printf
(
"%.1f,"
,
ck
::
type_convert
<
float
>
(
dst
[
n
*
K
+
k
]));
}
printf
(
"
\n
"
);
}
}
// void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
// const int NRepeat = 1;
// const int KRepeat = 8;
// const int NWave = 4;
// const int KLane = 2;
// const int NLane = 32;
// const int KPack = 16;
// int K0 = K / (KRepeat * KLane * KPack);
// int tempn, tempk;
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; ++k) {
// int n0 = n / (NRepeat * NLane * NWave);
// int k0 = k / (KRepeat * KLane * KPack);
// tempn = n % (NRepeat * NLane * NWave);
// tempk = k % (KRepeat * KLane * KPack);
// int n1 = tempn / (NLane * NWave);
// int k1 = tempk / (KLane * KPack);
// tempn = tempn % (NLane * NWave);
// tempk = tempk % (KLane * KPack);
// int n2 = tempn / NLane;
// int k2 = tempk / KPack;
// int n3 = tempn % NLane;
// int k3 = tempk % KPack;
// int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
// + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
// + n1 * KPack * NLane * KLane * NWave * KRepeat
// + k1 * KPack * NLane * KLane * NWave
// + n2 * KPack * NLane * KLane
// + k2 * KPack * NLane
// + n3 * KPack
// + k3;
// dst[outputIndex] = src[n * K + k];
// }
// }
// }
void
preShuffleBuffer
(
const
FP8
*
src
,
int
N
,
int
K
,
FP8
*
dst
)
{
void
preShuffleBuffer
(
const
FP8
*
src
,
int
N
,
int
K
,
FP8
*
dst
)
{
const
int
NRepeat
=
1
;
const
int
NRepeat
=
1
;
const
int
KRepeat
=
8
;
const
int
KRepeat
=
8
;
...
@@ -230,7 +125,8 @@ using CDEElementOp = MultiplyMultiply;
...
@@ -230,7 +125,8 @@ using CDEElementOp = MultiplyMultiply;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// clang-format off
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...
@@ -349,10 +245,7 @@ int main(int argc, char* argv[])
...
@@ -349,10 +245,7 @@ int main(int argc, char* argv[])
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// reinit2(a0_m_k.mData.data(), M, K);
// reinit2(b0_k_n.mData.data(), N, K);
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
N
,
K
,
b0_preshuffled
.
mData
.
data
());
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
N
,
K
,
b0_preshuffled
.
mData
.
data
());
// dump(b0_preshuffled.mData.data(), N, K);
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
0 → 100644
View file @
3f9dbcac
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
3f9dbcac
...
@@ -281,8 +281,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -281,8 +281,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
const
ABlockDesc
&
a_block_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf0
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
&
a_block_buf1
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
const
BBlockDesc
&
b_block_desc
,
...
@@ -301,17 +300,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -301,17 +300,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Global prefetch 1
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
0
>
{}
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// // Local prefill 1
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf0
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
//
//
Global prefetch 2
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
// Initialize C
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
...
@@ -322,12 +325,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -322,12 +325,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
0
,
a_block_buf
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
...
@@ -336,61 +348,13 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -336,61 +348,13 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t
i
=
0
;
index_t
i
=
0
;
do
do
{
{
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf1
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
});
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf0
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
0
>
{}
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -399,12 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -399,12 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -428,75 +395,43 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -428,75 +395,43 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf
0
,
a_block_buf
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
});
});
});
HotLoopScheduler
();
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
2
;
}
while
(
i
<
(
num_loop
-
2
));
i
+=
1
;
}
while
(
i
<
(
num_loop
-
1
));
}
}
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
{
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
;
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf1
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
b_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
I0
,
k0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -520,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -520,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
protected:
protected:
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_copy_
;
using
Base
::
a_thread_desc_
;
using
Base
::
a_thread_desc_
;
//
using Base::b_thread_copy_;
using
Base
::
b_thread_copy_
;
using
Base
::
b_thread_desc_
;
using
Base
::
b_thread_desc_
;
using
Base
::
c_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
3f9dbcac
...
@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
...
@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
// Tail number could be Odd or Even
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
{
//
if(arg.KBatch > 1)
if
(
arg
.
KBatch
>
1
)
//
{
{
//
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
//
{
{
//
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
//
GridwiseGemm,
GridwiseGemm
,
//
true,
true
,
//
InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum
::
AtomicAdd
,
//
minimum_occupancy,
minimum_occupancy
,
//
TailNumber::Odd>;
TailNumber
::
Odd
>
;
//
Run(kernel);
Run
(
kernel
);
//
}
}
//
else
else
//
{
{
//
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
//
GridwiseGemm,
GridwiseGemm
,
//
true,
true
,
//
InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum
::
AtomicAdd
,
//
minimum_occupancy,
minimum_occupancy
,
//
TailNumber::Even>;
TailNumber
::
Even
>
;
//
Run(kernel);
Run
(
kernel
);
//
}
}
//
}
}
//
else
else
//
{
{
//
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
//
{
{
//
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
//
GridwiseGemm,
GridwiseGemm
,
//
true,
true
,
//
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum
::
Set
,
//
minimum_occupancy,
minimum_occupancy
,
//
TailNumber::Odd>;
TailNumber
::
Odd
>
;
//
Run(kernel);
Run
(
kernel
);
//
}
}
//
else
else
//
{
{
//
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds
<
//
GridwiseGemm,
GridwiseGemm
,
//
true,
true
,
//
InMemoryDataOperationEnum::Set,
InMemoryDataOperationEnum
::
Set
,
//
minimum_occupancy,
minimum_occupancy
,
//
TailNumber::Even>;
TailNumber
::
Even
>
;
//
Run(kernel);
Run
(
kernel
);
//
}
}
//
}
}
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
0 → 100644
View file @
3f9dbcac
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
3f9dbcac
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
0 → 100644
View file @
3f9dbcac
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment