Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0dbe5370
Commit
0dbe5370
authored
Jan 02, 2025
by
aska-0096
Browse files
refine weight preshuffle format.
parent
72c1ddac
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
223 additions
and
195 deletions
+223
-195
example/65_gemm_multiply_multiply/CMakeLists.txt
example/65_gemm_multiply_multiply/CMakeLists.txt
+2
-1
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
...y_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
+31
-71
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
...gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
+65
-29
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
...eration/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+11
-0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+52
-23
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+23
-20
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
...profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
+36
-48
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+2
-2
No files found.
example/65_gemm_multiply_multiply/CMakeLists.txt
View file @
0dbe5370
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp
)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp
View file @
0dbe5370
...
@@ -39,7 +39,7 @@ using CShuffleDataType = F32;
...
@@ -39,7 +39,7 @@ using CShuffleDataType = F32;
using
D0DataType
=
F32
;
using
D0DataType
=
F32
;
using
D1DataType
=
F32
;
using
D1DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EDataType
=
B
F16
;
using
EDataType
=
F16
;
using
A0Layout
=
Row
;
using
A0Layout
=
Row
;
using
B0Layout
=
Col
;
using
B0Layout
=
Col
;
...
@@ -97,63 +97,32 @@ struct MultiplyMultiply
...
@@ -97,63 +97,32 @@ struct MultiplyMultiply
}
}
};
};
void
preShuffleBuffer
(
const
FP8
*
src
,
void
preShuffleBuffer
(
const
FP8
*
src
,
FP8
*
dst
,
int
N
,
int
K
,
int
NXdl
)
FP8
*
dst
,
int
N
,
int
K
,
int
NRepeat
,
int
KRepeat
,
int
NWave
,
int
KLane
,
int
NLane
,
int
KPack
)
{
{
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
int
KPack
=
16
;
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack, move KLane inner to make all
int
NLane
=
NXdl
;
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int
KLane
=
64
/
NLane
;
int
tempn
,
tempk
;
int
N0
=
N
/
NLane
;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> K0 N0 KLane NLane KPack
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
int
n0
=
n
/
(
NRepeat
*
NLane
*
NWave
);
int
n0
=
n
/
NLane
;
int
k0
=
k
/
(
KRepeat
*
KLane
*
KPack
);
int
n1
=
n
%
NLane
;
tempn
=
n
%
(
NRepeat
*
NLane
*
NWave
);
tempk
=
k
%
(
KRepeat
*
KLane
*
KPack
);
int
k0
=
k
/
(
KLane
*
KPack
);
tempk
=
k
%
(
KLane
*
KPack
);
int
n1
=
tempn
/
(
NLane
*
NWave
);
int
k1
=
tempk
/
KPack
;
int
k1
=
tempk
/
(
KRepeat
*
KPack
);
// Klane
int
k2
=
tempk
%
KPack
;
tempn
=
tempn
%
(
NLane
*
NWave
);
tempk
=
tempk
%
(
KRepeat
*
KPack
);
int
outputIndex
=
k0
*
KPack
*
NLane
*
KLane
*
N0
+
n0
*
KPack
*
NLane
*
KLane
+
int
n2
=
tempn
/
NLane
;
k1
*
KPack
*
NLane
+
n1
*
KPack
+
k2
;
int
k2
=
tempk
/
KPack
;
// KRepeat
int
n3
=
tempn
%
NLane
;
int
k3
=
tempk
%
KPack
;
// Kpack
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
*
NRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k2
*
KPack
*
NLane
*
KLane
*
NWave
+
n2
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n3
*
KPack
+
k3
;
#if 0
int k1 = tempk / (KLane * KPack); //KRepeat
int n1 = tempn / (NLane * NWave); //NRepeat
tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack);
int n2 = tempn / NLane; // NWave
int k2 = tempk / KPack; // KLane
int n3 = tempn % NLane; // NLane
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat +
k1 * KPack * NLane * KLane * NWave * NRepeat +
n1 * KPack * NLane * KLane * NWave +
n2 * KPack * NLane * KLane +
k2 * KPack * NLane +
n3 * KPack +
k3;
#endif
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
}
}
}
}
...
@@ -179,13 +148,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
...
@@ -179,13 +148,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
2
56
,
256
,
128
,
3
2
,
256
,
256
,
16
,
16
,
16
,
16
,
32
,
32
,
32
,
32
,
8
,
2
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
FP8
>
;
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
FP8
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...
@@ -319,18 +288,9 @@ int main(int argc, char* argv[])
...
@@ -319,18 +288,9 @@ int main(int argc, char* argv[])
// do GEMM
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
device_op
=
DeviceOpInstance
{};
auto
preshuffle_params
=
device_op
.
GetPreShuffleParameters
();
int
NPerXdl
=
device_op
.
GetPreShuffleParameters
();
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
b0_preshuffled
.
mData
.
data
(),
N
,
K
,
NPerXdl
);
b0_preshuffled
.
mData
.
data
(),
N
,
K
,
preshuffle_params
[
0
],
preshuffle_params
[
1
],
preshuffle_params
[
2
],
preshuffle_params
[
3
],
preshuffle_params
[
4
],
preshuffle_params
[
5
]);
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
View file @
0dbe5370
...
@@ -118,12 +118,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -118,12 +118,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
;
KPack
>
;
using
Base
::
A_K1
;
using
Base
::
I0
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
I1
;
using
Base
::
KRepeat
;
using
Base
::
KRepeat
;
using
Base
::
xdlops_gemm
;
using
Base
::
xdlops_gemm
;
using
typename
Base
::
HotLoopInstList
;
using
typename
Base
::
HotLoopInstList
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
CalculateCThreadOriginDataIndex8D
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
...
@@ -136,8 +138,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -136,8 +138,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
using
Base
::
BMmaKStride
;
...
@@ -145,6 +145,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -145,6 +145,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
template
<
typename
TileDesc_M0_M1_M2_K
>
__host__
__device__
static
constexpr
auto
MakeAGemmMmaTileDescriptor
(
const
TileDesc_M0_M1_M2_K
&
)
{
constexpr
index_t
M0
=
TileDesc_M0_M1_M2_K
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
M1
=
TileDesc_M0_M1_M2_K
{}.
GetLength
(
Number
<
1
>
{});
constexpr
index_t
M2
=
TileDesc_M0_M1_M2_K
{}.
GetLength
(
Number
<
2
>
{});
constexpr
index_t
K2
=
KPack
;
constexpr
index_t
K1
=
64
/
NPerXDL
;
constexpr
index_t
K0
=
KRepeat
;
return
transform_tensor_descriptor
(
TileDesc_M0_M1_M2_K
{},
make_tuple
(
make_pass_through_transform
(
Number
<
M0
>
{}),
make_pass_through_transform
(
Number
<
M1
>
{}),
make_pass_through_transform
(
Number
<
M2
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{},
Number
<
K2
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k0_k1_k2
=
MakeAGemmMmaTileDescriptor
(
a_block_desc_m0_m1_m2_k
);
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
{
return
num_loop
>
PrefetchStages
;
return
num_loop
>
PrefetchStages
;
...
@@ -275,11 +299,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -275,11 +299,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf0
,
a_block_buf0
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
});
});
...
@@ -305,12 +329,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -305,12 +329,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k
0
,
0
>,
.
template
GetSrcThreadScratchIdx
<
Sequence
<
k0
,
n0
,
0
,
0
>,
Number
<
0
>
{}
>
();
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -332,11 +356,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -332,11 +356,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf1
,
a_block_buf1
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
});
});
...
@@ -357,15 +381,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -357,15 +381,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
// b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0,
// k0, 0>,
b_blockwise_copy
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k
0
,
0
>,
.
template
GetSrcThreadScratchIdx
<
Sequence
<
k0
,
n0
,
0
,
0
>,
Number
<
1
>
{}
>
();
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -387,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -387,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf0
,
a_block_buf0
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
});
});
...
@@ -411,12 +433,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -411,12 +433,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k
0
,
0
>,
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
k0
,
n0
,
0
,
0
>,
Number
<
0
>
{}
>
();
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -436,11 +458,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -436,11 +458,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
0_k1_k2
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I0
,
I0
),
a_block_buf1
,
a_block_buf1
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
0
,
I0
),
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
I0
),
a_thread_buf
);
a_thread_buf
);
});
});
});
});
...
@@ -452,12 +474,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -452,12 +474,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k
0
,
0
>,
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
k0
,
n0
,
0
,
0
>,
Number
<
1
>
{}
>
();
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -483,12 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -483,12 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k
0
,
0
>,
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
k0
,
n0
,
0
,
0
>,
Number
<
0
>
{}
>
();
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k
0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
I0
,
k0
,
I
0
,
ik
))
>
{}];
});
});
using
mfma_input_type
=
using
mfma_input_type
=
...
@@ -507,9 +530,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -507,9 +530,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
}
}
protected:
protected:
using
Base
::
a_thread_copy_
;
// MRepeat MWave MLane KRepeat KLane KPack
using
Base
::
a_thread_desc_
;
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
using
Base
::
b_thread_desc_
;
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KRepeat
>
{},
I1
,
Number
<
KPack
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ComputeDataType
,
decltype
(
a_block_desc_m0_m1_m2_k0_k1_k2
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
AThreadCopy
a_thread_copy_
{
Base
::
CalculateAThreadOriginDataIndex6D
()};
using
Base
::
c_thread_desc_
;
using
Base
::
c_thread_desc_
;
};
};
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
View file @
0dbe5370
...
@@ -113,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -113,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
}
__device__
static
auto
CalculateAThreadOriginDataIndex6D
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
0
,
xdlops_a_idx
[
I0
],
0
);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
wave_idx
=
GetWaveIdx
();
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
0dbe5370
...
@@ -138,7 +138,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
...
@@ -138,7 +138,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
array
<
int
,
6
>
GetPreShuffleParameters
()
=
0
;
virtual
int
GetPreShuffleParameters
()
=
0
;
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
0dbe5370
...
@@ -139,16 +139,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -139,16 +139,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using
Argument
=
typename
GridwiseGemm
::
Argument
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
std
::
array
<
int
,
6
>
GetPreShuffleParameters
()
override
int
GetPreShuffleParameters
()
override
{
{
std
::
array
<
int
,
6
>
preshuffle_params
{
NXdlPerWave
,
return
NPerXDL
;
GridwiseGemm
::
KRepeat
,
GridwiseGemm
::
NWave
,
GridwiseGemm
::
KLane
,
GridwiseGemm
::
NLane
,
GridwiseGemm
::
KPack
};
return
preshuffle_params
;
}
}
// Invoker
// Invoker
...
@@ -240,8 +233,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -240,8 +233,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
};
};
constexpr
index_t
minimum_occupancy
=
constexpr
index_t
minimum_occupancy
=
[]()
{
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
return
(
MPerBlock
*
NPerBlock
/
BlockSize
<=
128
)
?
2
:
1
;
}
else
{
return
1
;
}
}();
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
...
@@ -307,21 +308,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -307,21 +308,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
GridwiseGemm
,
{
false
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
>
;
false
,
Run
(
kernel
);
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
GridwiseGemm
,
{
false
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
InMemoryDataOperationEnum
::
Set
,
GridwiseGemm
,
minimum_occupancy
>
;
false
,
Run
(
kernel
);
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
0dbe5370
...
@@ -141,8 +141,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -141,8 +141,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static_assert
(
NWave
*
warpSize
==
BlockSize
);
// static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static
constexpr
auto
MakeDsGridPointer
()
static
constexpr
auto
MakeDsGridPointer
()
{
{
...
@@ -176,7 +175,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -176,7 +175,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
auto
CalculateBN0Shuffled
(
index_t
N
)
__host__
__device__
static
auto
CalculateBN0Shuffled
(
index_t
N
)
{
{
return
math
::
integer_divide_ceil
(
N
,
NLane
*
NWave
);
return
math
::
integer_divide_ceil
(
N
,
NLane
);
}
}
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
)
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
)
{
{
...
@@ -322,9 +321,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -322,9 +321,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
{
{
constexpr
index_t
NkSwizzleNumber
=
Number
<
Block
Size
*
KPack
>
{};
constexpr
index_t
NkSwizzleNumber
=
Number
<
warp
Size
*
KPack
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
return
make_naive_tensor_descriptor
(
make_tuple
(
K0
,
N0
/
NWave
,
NWave
,
NkSwizzleNumber
),
make_tuple
(
K0
*
NkSwizzleNumber
,
NkSwizzleNumber
,
I1
));
make_tuple
(
N0
*
NkSwizzleNumber
,
NWave
*
NkSwizzleNumber
,
NkSwizzleNumber
,
I1
));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
...
@@ -649,8 +648,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -649,8 +648,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
{
// KPack * NLane * KLane * N
Wave * KRepeat * K0* NRepeat
*
N
0
// KPack * NLane * KLane * N
0
*
K
0
b_k_split_offset
=
k_id
*
karg
.
KRead
*
NLane
*
NWave
;
b_k_split_offset
=
k_id
*
karg
.
KRead
*
karg
.
N
;
}
}
if
(
k_id
<
karg
.
KBatch
-
1
)
if
(
k_id
<
karg
.
KBatch
-
1
)
...
@@ -1159,6 +1158,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1159,6 +1158,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
}
// check gridwise gemm pipeline
// check gridwise gemm pipeline
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
...
@@ -1168,7 +1168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1168,7 +1168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return false;
return false;
}
}
}
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
...
@@ -1252,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1252,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
...
@@ -1294,7 +1295,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1294,7 +1295,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// dummy
constexpr
auto
b_block_desc_bk0_n_bk1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
));
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
...
@@ -1335,17 +1338,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1335,17 +1338,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BElementwiseOperation
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
NXdlPerWave
,
KRepeat
,
KPack
*
Block
Size
>
,
Sequence
<
KRepeat
,
NXdlPerWave
,
NWave
,
KPack
*
warp
Size
>
,
Sequence
<
1
,
1
,
Block
Size
>
,
// BThreadClusterLengths,
Sequence
<
1
,
1
,
NWave
,
warp
Size
>
,
// BThreadClusterLengths,
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// BBlockTransferClusterArrangeOrder,
BDataType
,
BDataType
,
LDSTypeB
,
LDSTypeB
,
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferSrcAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// BBlockTransferSrcAccessOrder,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
2
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
...
@@ -1353,10 +1356,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1353,10 +1356,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
true
,
2
>
(
b_grid_desc_bpreshuffled
,
2
>
(
b_grid_desc_bpreshuffled
,
make_multi_index
(
n_block_data_idx_on_grid
,
0
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
,
0
),
b_element_op
,
b_element_op
,
b_block_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
...
@@ -1367,7 +1370,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1367,7 +1370,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static_cast
<
LDSTypeA
*>
(
p_shared1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
LDSTypeA
*>
(
p_shared1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
KRepeat
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KRepeat
,
0
,
0
,
0
);
// Blockwise GEMM pipeline
// Blockwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
BlockwiseGemmPipe
>
);
static_assert
(
std
::
is_default_constructible_v
<
BlockwiseGemmPipe
>
);
...
...
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
View file @
0dbe5370
...
@@ -29,40 +29,31 @@ void preShuffleBuffer(const InOutDataType* src,
...
@@ -29,40 +29,31 @@ void preShuffleBuffer(const InOutDataType* src,
InOutDataType
*
dst
,
InOutDataType
*
dst
,
int
N
,
int
N
,
int
K
,
int
K
,
int
NRepeat
,
int
NXdl
)
int
KRepeat
,
int
NWave
,
int
KLane
,
int
NLane
,
int
KPack
)
{
{
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
int
KPack
=
16
;
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
int
NLane
=
NXdl
;
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int
KLane
=
64
/
NLane
;
int
tempn
,
tempk
;
int
N0
=
N
/
NLane
;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> K0 N0 KLane NLane KPack
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
int
n0
=
n
/
(
NRepeat
*
NLane
*
NWave
);
int
n0
=
n
/
NLane
;
int
k0
=
k
/
(
KRepeat
*
KLane
*
KPack
);
int
n1
=
n
%
NLane
;
tempn
=
n
%
(
NRepeat
*
NLane
*
NWave
);
tempk
=
k
%
(
KRepeat
*
KLane
*
KPack
);
int
k0
=
k
/
(
KLane
*
KPack
);
tempk
=
k
%
(
KLane
*
KPack
);
int
n1
=
tempn
/
(
NLane
*
NWave
);
int
k1
=
tempk
/
KPack
;
int
k1
=
tempk
/
(
KRepeat
*
KPack
);
// Klane
int
k2
=
tempk
%
KPack
;
tempn
=
tempn
%
(
NLane
*
NWave
);
tempk
=
tempk
%
(
KRepeat
*
KPack
);
int
outputIndex
=
k0
*
KPack
*
NLane
*
KLane
*
N0
+
n0
*
KPack
*
NLane
*
KLane
+
int
n2
=
tempn
/
NLane
;
k1
*
KPack
*
NLane
+
n1
*
KPack
+
k2
;
int
k2
=
tempk
/
KPack
;
// KRepeat
int
n3
=
tempn
%
NLane
;
int
k3
=
tempk
%
KPack
;
// Kpack
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
*
NRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k2
*
KPack
*
NLane
*
KLane
*
NWave
+
n2
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n3
*
KPack
+
k3
;
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
}
}
...
@@ -116,7 +107,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -116,7 +107,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_preshuffled
(
Tensor
<
BDataType
>
b_preshuffled_mfma16
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
// use layout only for size
Tensor
<
BDataType
>
b_preshuffled_mfma32
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
// use layout only for size
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
// use layout only for size
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
...
@@ -154,6 +147,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -154,6 +147,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
}
preShuffleBuffer
(
b_k_n
.
mData
.
data
(),
b_preshuffled_mfma16
.
mData
.
data
(),
N
,
K
,
16
);
preShuffleBuffer
(
b_k_n
.
mData
.
data
(),
b_preshuffled_mfma32
.
mData
.
data
(),
N
,
K
,
32
);
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
MultiplyMultiply
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
;
using
MultiplyMultiply
=
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
;
...
@@ -166,12 +162,16 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -166,12 +162,16 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
const
auto
c_element_op
=
CElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_mfma16
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_mfma32
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf_mfma16
.
ToDevice
(
b_preshuffled_mfma16
.
mData
.
data
());
b_device_buf_mfma32
.
ToDevice
(
b_preshuffled_mfma32
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
...
@@ -234,20 +234,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -234,20 +234,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
// profile device GEMM instances
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
auto
preshuffle_params
=
op_ptr
->
GetPreShuffleParameters
();
int
NPerXdl
=
op_ptr
->
GetPreShuffleParameters
();
preShuffleBuffer
(
b_k_n
.
mData
.
data
(),
b_preshuffled
.
mData
.
data
(),
N
,
K
,
preshuffle_params
[
0
],
preshuffle_params
[
1
],
preshuffle_params
[
2
],
preshuffle_params
[
3
],
preshuffle_params
[
4
],
preshuffle_params
[
5
]);
b_device_buf
.
ToDevice
(
b_preshuffled
.
mData
.
data
());
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
};
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
};
...
@@ -262,7 +249,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -262,7 +249,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
NPerXdl
==
16
?
b_device_buf_mfma16
.
GetDeviceBuffer
()
:
b_device_buf_mfma32
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
2
>
{
d0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
2
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
d1_device_buf
.
GetDeviceBuffer
()},
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
...
@@ -298,8 +286,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -298,8 +286,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
is_same_v
<
EDataType
,
int8_t
>
))
is_same_v
<
EDataType
,
int8_t
>
))
{
{
std
::
string
msg
=
"Error: Incorrect results!"
;
std
::
string
msg
=
"Error: Incorrect results!"
;
double
rtol
=
1e-
1
;
double
rtol
=
1e-
3
;
double
atol
=
1
e-
1
;
double
atol
=
5
e-
2
;
pass
=
pass
&
ck
::
utils
::
check_err
(
pass
=
pass
&
ck
::
utils
::
check_err
(
e_m_n_device_result
,
e_m_n_host_result
,
msg
,
rtol
,
atol
);
e_m_n_device_result
,
e_m_n_host_result
,
msg
,
rtol
,
atol
);
}
}
...
...
profiler/src/CMakeLists.txt
View file @
0dbe5370
...
@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
...
@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# endif()
# endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp
)
#
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp
)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# endif()
# endif()
...
@@ -137,7 +137,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
...
@@ -137,7 +137,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_instance
)
#
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance
)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# endif()
# endif()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment