Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
22fe522d
Commit
22fe522d
authored
Jan 08, 2025
by
aska-0096
Browse files
optimize software pipeline
parent
9dd74e0d
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1479 additions
and
426 deletions
+1479
-426
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
...y_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
+10
-7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
.../blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
+91
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
+506
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
.../block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp
+165
-147
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+135
-27
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+510
-179
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
..._multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
+29
-29
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp
...y_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp
+29
-29
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
...profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
+4
-8
No files found.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
View file @
22fe522d
...
...
@@ -103,10 +103,10 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
int
NLane
=
NXdl
;
int
KLane
=
64
/
NLane
;
int
N
0
=
N
/
N
Lane
;
int
K
0
=
K
/
(
K
Lane
*
KPack
)
;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K ->
K
0
N
0 KLane NLane KPack
// N, K ->
N
0
K
0 KLane NLane KPack
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
...
...
@@ -120,7 +120,7 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
int
k1
=
tempk
/
KPack
;
int
k2
=
tempk
%
KPack
;
int
outputIndex
=
k
0
*
KPack
*
NLane
*
KLane
*
N
0
+
n
0
*
KPack
*
NLane
*
KLane
+
int
outputIndex
=
n
0
*
KPack
*
NLane
*
KLane
*
K
0
+
k
0
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n1
*
KPack
+
k2
;
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
...
...
@@ -148,14 +148,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
32
,
256
,
256
,
32
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
2
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v
3
,
FP8
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v
1
,
FP8
>
;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...
...
@@ -366,7 +366,10 @@ int main(int argc, char* argv[])
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
"Error: Incorrect results!"
,
1e-3
,
5e-2
)
?
0
:
1
;
return
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
"Error: Incorrect results!"
,
1e-3
,
5e-2
)
?
0
:
1
;
}
return
0
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp
0 → 100644
View file @
22fe522d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
namespace
ck
{
/// Identifies which B-preshuffle blockwise GEMM pipeline implementation to use.
enum class BlockGemmPipelineVersion
{
    /// Pipeline variant that uses a single LDS buffer.
    v1,
    /// Pipeline variant that uses a double LDS buffer.
    v2
};
/// Compile-time selector for the B-preshuffle blockwise GEMM pipeline.
///
/// Maps `BlkGemmPipelineVer` to a concrete pipeline type and returns a
/// value-initialized instance of it:
///   - `BlockGemmPipelineVersion::v1` -> `BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<...>`
///   - `BlockGemmPipelineVersion::v2` -> `BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<...>`
///
/// Every template parameter after `BlkGemmPipelineVer` is forwarded unchanged,
/// and in the same order, to the selected pipeline template.
///
/// @tparam BlkGemmPipelineVer Pipeline version to instantiate.
/// @tparam BlkGemmPipeSche    Scheduler policy forwarded to the pipeline.
/// @return A default-constructed object of the selected pipeline type.
///
/// Any other version value is rejected at compile time. The fallback used to
/// stream a message to `std::cerr` and fall off the end of the function, which
/// (a) relied on `<iostream>` being included transitively, (b) is not valid in a
/// constant expression, and (c) silently deduced a `void` return type so the
/// real error showed up far away from its cause.
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
constexpr auto BlockGemmBPreshufflePipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
                                                           BlockSize,
                                                           ADataType,
                                                           BDataType,
                                                           ComputeDataType,
                                                           AccDataType,
                                                           ATileDesc,
                                                           BTileDesc,
                                                           AMmaTileDesc,
                                                           BMmaTileDesc,
                                                           ABlockTransferSrcScalarPerVector,
                                                           BBlockTransferSrcScalarPerVector,
                                                           MPerBlock,
                                                           NPerBlock,
                                                           KPerBlock,
                                                           MPerXDL,
                                                           NPerXDL,
                                                           MRepeat,
                                                           NRepeat,
                                                           KPack>{};
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
    {
        return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlkGemmPipeSche,
                                                           BlockSize,
                                                           ADataType,
                                                           BDataType,
                                                           ComputeDataType,
                                                           AccDataType,
                                                           ATileDesc,
                                                           BTileDesc,
                                                           AMmaTileDesc,
                                                           BMmaTileDesc,
                                                           ABlockTransferSrcScalarPerVector,
                                                           BBlockTransferSrcScalarPerVector,
                                                           MPerBlock,
                                                           NPerBlock,
                                                           KPerBlock,
                                                           MPerXDL,
                                                           NPerXDL,
                                                           MRepeat,
                                                           NRepeat,
                                                           KPack>{};
    }
    else
    {
        // The condition depends on a template parameter, so it is only checked
        // when this branch is actually instantiated, i.e. for an unsupported
        // version value.
        static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
                          BlkGemmPipelineVer == BlockGemmPipelineVersion::v2,
                      "BlockGemmPipeline configuration is not available");
    }
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
0 → 100644
View file @
22fe522d
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
→
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle
_v2
.hpp
View file @
22fe522d
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
22fe522d
...
...
@@ -139,10 +139,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using
Argument
=
typename
GridwiseGemm
::
Argument
;
int
GetPreShuffleParameters
()
override
{
return
NPerXDL
;
}
int
GetPreShuffleParameters
()
override
{
return
NPerXDL
;
}
// Invoker
struct
Invoker
:
public
BaseInvoker
...
...
@@ -233,16 +230,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
};
constexpr
index_t
minimum_occupancy
=
[]()
{
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
return
(
M
PerBlock
*
N
PerBlock
/
BlockSize
<=
128
)
?
2
:
1
;
}
else
{
return
1
;
}
}()
;
constexpr
auto
estimated_reg_a
=
MPerBlock
*
KPerBlock
*
sizeof
(
ADataType
)
/
BlockSize
/
4
*
(
1
+
GridwiseGemm
::
NWave
);
constexpr
auto
estimated_reg_b
=
N
PerBlock
*
K
PerBlock
*
sizeof
(
BDataType
)
/
BlockSize
/
4
*
(
2
)
;
constexpr
auto
estimated_reg_c
=
MPerBlock
*
NPerBlock
*
sizeof
(
GemmAccDataType
)
/
BlockSize
/
4
;
constexpr
auto
estimated_reg_total
=
estimated_reg_a
+
estimated_reg_b
+
estimated_reg_c
;
constexpr
index_t
minimum_occupancy
=
(
estimated_reg_total
>=
256
)
?
1
:
2
;
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
...
...
@@ -250,7 +247,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v
3
)
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v
1
)
{
if
(
arg
.
KBatch
>
1
)
{
...
...
@@ -299,16 +296,72 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
}
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
arg
.
KBatch
>
1
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
{
throw
std
::
runtime_error
(
"todo: only v
3
support now"
);
throw
std
::
runtime_error
(
"todo: only v
1 & v2
support now"
);
}
}
#if 0
else
{
if
(
arg
.
KBatch
>
1
)
if
constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v
1)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
...
...
@@ -328,10 +381,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
TailNumber::Even>;
Run(kernel);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
...
...
@@ -351,8 +404,67 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v3 support now");
}
}
#endif
return
ave_time
;
}
...
...
@@ -490,11 +602,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
},
{
BlockGemmPipelineVersion
::
v4
,
"v4"
},
{
BlockGemmPipelineVersion
::
v5
,
"v5"
}};
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmXdlUniversal"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
22fe522d
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
View file @
22fe522d
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp
View file @
22fe522d
This diff is collapsed.
Click to expand it.
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
View file @
22fe522d
...
...
@@ -25,20 +25,16 @@ namespace ck {
namespace
profiler
{
template
<
typename
InOutDataType
>
void
preShuffleBuffer
(
const
InOutDataType
*
src
,
InOutDataType
*
dst
,
int
N
,
int
K
,
int
NXdl
)
void
preShuffleBuffer
(
const
InOutDataType
*
src
,
InOutDataType
*
dst
,
int
N
,
int
K
,
int
NXdl
)
{
int
KPack
=
16
;
int
NLane
=
NXdl
;
int
KLane
=
64
/
NLane
;
int
N
0
=
N
/
N
Lane
;
int
K
0
=
K
/
(
K
Lane
*
KPack
)
;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K ->
K
0
N
0 KLane NLane KPack
// N, K ->
N
0
K
0 KLane NLane KPack
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
...
...
@@ -52,7 +48,7 @@ void preShuffleBuffer(const InOutDataType* src,
int
k1
=
tempk
/
KPack
;
int
k2
=
tempk
%
KPack
;
int
outputIndex
=
k
0
*
KPack
*
NLane
*
KLane
*
N
0
+
n
0
*
KPack
*
NLane
*
KLane
+
int
outputIndex
=
n
0
*
KPack
*
NLane
*
KLane
*
K
0
+
k
0
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n1
*
KPack
+
k2
;
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment