Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f60f9d59
Commit
f60f9d59
authored
Dec 30, 2024
by
aska-0096
Browse files
sanity pass, most tile size enabled. TODO: NWave!=4
parent
482ca684
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
556 additions
and
336 deletions
+556
-336
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
...gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
+63
-22
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
...gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
+43
-11
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+45
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+116
-62
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+10
-9
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
...instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
+102
-95
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
..._multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
+33
-46
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
+11
-11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
+11
-11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
+11
-11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
+11
-11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
+11
-11
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
...shuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
+11
-11
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
...profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
+74
-21
profiler/src/profile_gemm_multiply_multiply_weight_preshuffle.cpp
.../src/profile_gemm_multiply_multiply_weight_preshuffle.cpp
+4
-4
No files found.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
View file @
f60f9d59
...
@@ -78,15 +78,18 @@ struct MultiplyMultiply
...
@@ -78,15 +78,18 @@ struct MultiplyMultiply
}
}
};
};
void
preShuffleBuffer
(
const
FP8
*
src
,
int
N
,
int
K
,
FP8
*
dst
)
void
preShuffleBuffer
(
const
FP8
*
src
,
FP8
*
dst
,
int
N
,
int
K
,
int
NRepeat
,
int
KRepeat
,
int
NWave
,
int
KLane
,
int
NLane
,
int
KPack
)
{
{
const
int
NRepeat
=
4
;
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
const
int
KRepeat
=
4
;
const
int
NWave
=
2
;
const
int
KLane
=
2
;
const
int
NLane
=
32
;
const
int
KPack
=
16
;
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int
tempn
,
tempk
;
int
tempn
,
tempk
;
...
@@ -108,12 +111,30 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst)
...
@@ -108,12 +111,30 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst)
int
n3
=
tempn
%
NLane
;
int
n3
=
tempn
%
NLane
;
int
k3
=
tempk
%
KPack
;
// Kpack
int
k3
=
tempk
%
KPack
;
// Kpack
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
*
K0
+
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
*
NRepeat
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k2
*
KPack
*
NLane
*
KLane
*
NWave
// switch k1, k2
k2
*
KPack
*
NLane
*
KLane
*
NWave
+
n2
*
KPack
*
NLane
*
KLane
+
+
n2
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n3
*
KPack
+
k3
;
k1
*
KPack
*
NLane
+
n3
*
KPack
+
k3
;
#if 0
int k1 = tempk / (KLane * KPack); //KRepeat
int n1 = tempn / (NLane * NWave); //NRepeat
tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack);
int n2 = tempn / NLane; // NWave
int k2 = tempk / KPack; // KLane
int n3 = tempn % NLane; // NLane
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat +
k1 * KPack * NLane * KLane * NWave * NRepeat +
n1 * KPack * NLane * KLane * NWave +
n2 * KPack * NLane * KLane +
k2 * KPack * NLane +
n3 * KPack +
k3;
#endif
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
}
}
}
}
...
@@ -124,7 +145,7 @@ using AElementOp = PassThrough;
...
@@ -124,7 +145,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
MultiplyMultiply
;
using
CDEElementOp
=
MultiplyMultiply
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -139,10 +160,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
...
@@ -139,10 +160,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
2
56
,
256
,
128
,
3
2
,
256
,
128
,
16
,
16
,
16
,
16
,
32
,
32
,
32
,
32
,
4
,
4
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
...
@@ -245,6 +266,12 @@ int main(int argc, char* argv[])
...
@@ -245,6 +266,12 @@ int main(int argc, char* argv[])
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
A0DataType
>
{});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D1DataType
>
{});
break
;
default:
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
...
@@ -256,10 +283,8 @@ int main(int argc, char* argv[])
...
@@ -256,10 +283,8 @@ int main(int argc, char* argv[])
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
N
,
K
,
b0_preshuffled
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
...
@@ -274,7 +299,23 @@ int main(int argc, char* argv[])
...
@@ -274,7 +299,23 @@ int main(int argc, char* argv[])
// do GEMM
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
preshuffle_params
=
device_op
.
GetPreShuffleParameters
();
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
b0_preshuffled
.
mData
.
data
(),
N
,
K
,
preshuffle_params
[
0
],
preshuffle_params
[
1
],
preshuffle_params
[
2
],
preshuffle_params
[
3
],
preshuffle_params
[
4
],
preshuffle_params
[
5
]);
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
auto
argument
=
device_op
.
MakeArgument
(
a0_device_buf
.
GetDeviceBuffer
(),
device_op
.
MakeArgument
(
a0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
...
@@ -300,7 +341,7 @@ int main(int argc, char* argv[])
...
@@ -300,7 +341,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
,
0
,
50
,
50
,
true
,
50
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
@@ -315,7 +356,7 @@ int main(int argc, char* argv[])
...
@@ -315,7 +356,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
,
1
,
1
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
View file @
f60f9d59
...
@@ -152,8 +152,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -152,8 +152,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
return
num_loop
%
2
==
0
?
TailNumber
::
Even
:
TailNumber
::
Odd
;
}
}
__device__
static
constexpr
auto
HotLoopScheduler
()
__device__
static
constexpr
auto
HotLoopScheduler
()
...
@@ -342,8 +342,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -342,8 +342,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
b_blockwise_copy
Number
<
0
>
{}
>
();
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
@@ -394,8 +395,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -394,8 +395,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
// b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0,
Number
<
1
>
{}
>
();
// k0, 0>,
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
@@ -435,7 +439,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -435,7 +439,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
2
));
}
}
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
{
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
...
@@ -445,8 +449,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -445,8 +449,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k0
,
0
>,
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
@@ -486,8 +490,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -486,8 +490,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k0
,
0
>,
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
@@ -510,6 +514,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
...
@@ -510,6 +514,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
// latency
// latency
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_sched_barrier(0);
}
}
else
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
n0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
}
}
protected:
protected:
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
f60f9d59
...
@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
...
@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGemmMultipleDSplitKBPreShuffle
:
public
BaseOperator
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
StrideDs
,
ck
::
index_t
StrideE
,
ck
::
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
array
<
int
,
6
>
GetPreShuffleParameters
()
=
0
;
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
f60f9d59
...
@@ -10,9 +10,10 @@
...
@@ -10,9 +10,10 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_gemm_multiple_d
_xdl_cshuffle_v3
.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/flush_cache.hpp"
...
@@ -69,55 +70,17 @@ template <typename ALayout,
...
@@ -69,55 +70,17 @@ template <typename ALayout,
typename
LDSTypeA
=
ComputeTypeA
,
typename
LDSTypeA
=
ComputeTypeA
,
typename
LDSTypeB
=
ComputeTypeB
>
typename
LDSTypeB
=
ComputeTypeB
>
struct
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
struct
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
:
public
DeviceGemmMultiD_Xdl_CShuffle_V3
<
:
public
DeviceGemmMultipleDSplitKBPreShuffle
<
ALayout
,
ALayout
,
BLayout
,
BLayout
,
DsLayout
,
DsLayout
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
DsDataType
,
DsDataType
,
CDataType
,
CDataType
,
AElementwiseOperation
,
GemmAccDataType
,
BElementwiseOperation
,
CShuffleDataType
,
CElementwiseOperation
>
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVectors
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
@@ -176,6 +139,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -176,6 +139,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using
Argument
=
typename
GridwiseGemm
::
Argument
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
std
::
array
<
int
,
6
>
GetPreShuffleParameters
()
override
{
std
::
array
<
int
,
6
>
preshuffle_params
{
NXdlPerWave
,
GridwiseGemm
::
KRepeat
,
GridwiseGemm
::
NWave
,
GridwiseGemm
::
KLane
,
GridwiseGemm
::
NLane
,
GridwiseGemm
::
KPack
};
return
preshuffle_params
;
}
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
...
@@ -278,21 +253,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -278,21 +253,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
>
;
true
,
Run
(
kernel
);
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
InMemoryDataOperationEnum
::
Set
,
GridwiseGemm
,
minimum_occupancy
>
;
true
,
Run
(
kernel
);
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
else
else
...
@@ -436,6 +439,57 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...
@@ -436,6 +439,57 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
}
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
BlockGemmPipelineScheduler
,
std
::
string
>
BlkGemmPipelineSchedulerToString
{
{
BlockGemmPipelineScheduler
::
Intrawave
,
"Intrawave"
},
{
BlockGemmPipelineScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
BlockGemmPipelineVersion
,
std
::
string
>
BlkGemmPipelineVersionToString
{
{
BlockGemmPipelineVersion
::
v1
,
"v1"
},
{
BlockGemmPipelineVersion
::
v2
,
"v2"
},
{
BlockGemmPipelineVersion
::
v3
,
"v3"
},
{
BlockGemmPipelineVersion
::
v4
,
"v4"
},
{
BlockGemmPipelineVersion
::
v5
,
"v5"
}};
// clang-format off
str
<<
"DeviceGemmXdlUniversal"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
std
::
string
(
ALayout
::
name
)[
0
]
<<
std
::
string
(
BLayout
::
name
)[
0
]
<<
std
::
string
(
CLayout
::
name
)[
0
]
<<
">"
<<
" BlkSize: "
<<
BlockSize
<<
", "
<<
"BlkTile: "
<<
MPerBlock
<<
"x"
<<
NPerBlock
<<
"x"
<<
KPerBlock
<<
", "
<<
"WaveTile: "
<<
MPerXDL
<<
"x"
<<
NPerXDL
<<
", "
<<
"WaveMap: "
<<
MXdlPerWave
<<
"x"
<<
NXdlPerWave
<<
", "
<<
"VmemReadVec: "
<<
ABlockTransferSrcScalarPerVector
<<
"x"
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
"BlkGemmPipelineScheduler: "
<<
BlkGemmPipelineSchedulerToString
[
BlkGemmPipeSched
]
<<
", "
<<
"BlkGemmPipelineVersion: "
<<
BlkGemmPipelineVersionToString
[
BlkGemmPipelineVer
]
<<
", "
<<
"BlkGemmPipelinePrefetchStages: "
<<
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
;
// clang-format on
return
str
.
str
();
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
f60f9d59
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp"
//
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...
@@ -31,7 +31,7 @@ template <typename GridwiseGemm,
...
@@ -31,7 +31,7 @@ template <typename GridwiseGemm,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
TailNumber
TailNum
=
TailNumber
::
Even
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
...
@@ -142,7 +142,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -142,7 +142,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static_assert
(
NXdlPerWave
==
1
,
"only 1 validated now, tbd next week"
);
//
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static
constexpr
auto
MakeDsGridPointer
()
static
constexpr
auto
MakeDsGridPointer
()
{
{
...
@@ -322,10 +322,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -322,10 +322,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
{
{
constexpr
index_t
NkSwizzle
=
BlockSize
*
KPack
;
constexpr
index_t
NkSwizzleNumber
=
Number
<
BlockSize
*
KPack
>
{};
constexpr
index_t
NkSwizzleNumber
=
Number
<
NkSwizzle
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
make_tuple
(
K0
*
NkSwizzle
,
NkSwizzleNumber
,
I1
));
make_tuple
(
K0
*
NkSwizzle
Number
,
NkSwizzleNumber
,
I1
));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
...
@@ -650,7 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -650,7 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
{
b_k_split_offset
=
k_id
*
karg
.
KRead
;
// KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0*N0
b_k_split_offset
=
k_id
*
karg
.
KRead
*
NLane
*
NWave
*
NXdlPerWave
;
}
}
if
(
k_id
<
karg
.
KBatch
-
1
)
if
(
k_id
<
karg
.
KBatch
-
1
)
...
@@ -1286,8 +1286,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1286,8 +1286,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
// N0, K0, Blocksize*KPack
const
index_t
n_block_data_idx_on_grid
=
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
(
NPerBlock
/
NLane
/
N
Wave
)
)
;
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NXdlPer
Wave
);
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
@@ -1334,7 +1335,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
...
@@ -1334,7 +1335,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BElementwiseOperation
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
KRepeat
,
KPack
*
BlockSize
>
,
Sequence
<
NXdlPerWave
,
KRepeat
,
KPack
*
BlockSize
>
,
Sequence
<
1
,
1
,
BlockSize
>
,
// BThreadClusterLengths,
Sequence
<
1
,
1
,
BlockSize
>
,
// BThreadClusterLengths,
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferClusterArrangeOrder,
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferClusterArrangeOrder,
BDataType
,
BDataType
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp
View file @
f60f9d59
...
@@ -20,7 +20,7 @@ namespace instance {
...
@@ -20,7 +20,7 @@ namespace instance {
#if 0
#if 0
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -33,7 +33,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -33,7 +33,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances);
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -46,7 +46,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -46,7 +46,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances);
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -59,7 +59,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -59,7 +59,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances);
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -72,7 +72,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -72,7 +72,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances);
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -85,7 +85,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -85,7 +85,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances);
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK
BPreShuffle
<Row,
Col,
Col,
Tuple<Row, Col>,
Tuple<Row, Col>,
Row,
Row,
...
@@ -101,82 +101,88 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
...
@@ -101,82 +101,88 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitKBPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
);
MultiplyMultiply
>>>&
instances
);
#endif
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -185,31 +191,32 @@ template <typename ADataType,
...
@@ -185,31 +191,32 @@ template <typename ADataType,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
>
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
struct
DeviceOperationInstanceFactory
<
ALayout
,
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitKBPreShuffle
<
BLayout
,
ALayout
,
Tuple
<
Row
,
Col
>
,
BLayout
,
CLayout
,
Tuple
<
Row
,
Col
>
,
ADataType
,
CLayout
,
BDataType
,
ADataType
,
Tuple
<
F32
,
F32
>
,
BDataType
,
CDataType
,
Tuple
<
F32
,
F32
>
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>>
{
{
using
DeviceOp
=
using
DeviceOp
=
DeviceGemmMultipleDSplitK
<
ALayout
,
DeviceGemmMultipleDSplitK
BPreShuffle
<
ALayout
,
BLayout
,
BLayout
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
CDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
ck
::
tensor_operation
::
element_wise
::
MultiplyMultiply
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp
View file @
f60f9d59
This diff is collapsed.
Click to expand it.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
View file @
f60f9d59
...
@@ -9,17 +9,17 @@ namespace device {
...
@@ -9,17 +9,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances
(
void
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
BPreShuffle
<
Row
,
Col
,
Col
,
Tuple
<
Row
,
Col
>
,
Tuple
<
Row
,
Col
>
,
Row
,
Row
,
F8
,
F8
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
Tuple
<
F32
,
F32
>
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
MultiplyMultiply
>>>&
instances
)
{
{
add_device_operation_instances
(
add_device_operation_instances
(
instances
,
instances
,
...
...
profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp
View file @
f60f9d59
...
@@ -24,6 +24,51 @@
...
@@ -24,6 +24,51 @@
namespace
ck
{
namespace
ck
{
namespace
profiler
{
namespace
profiler
{
template
<
typename
InOutDataType
>
void
preShuffleBuffer
(
const
InOutDataType
*
src
,
InOutDataType
*
dst
,
int
N
,
int
K
,
int
NRepeat
,
int
KRepeat
,
int
NWave
,
int
KLane
,
int
NLane
,
int
KPack
)
{
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int
tempn
,
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
int
n0
=
n
/
(
NRepeat
*
NLane
*
NWave
);
int
k0
=
k
/
(
KRepeat
*
KLane
*
KPack
);
tempn
=
n
%
(
NRepeat
*
NLane
*
NWave
);
tempk
=
k
%
(
KRepeat
*
KLane
*
KPack
);
int
n1
=
tempn
/
(
NLane
*
NWave
);
int
k1
=
tempk
/
(
KRepeat
*
KPack
);
// Klane
tempn
=
tempn
%
(
NLane
*
NWave
);
tempk
=
tempk
%
(
KRepeat
*
KPack
);
int
n2
=
tempn
/
NLane
;
int
k2
=
tempk
/
KPack
;
// KRepeat
int
n3
=
tempn
%
NLane
;
int
k3
=
tempk
%
KPack
;
// Kpack
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
*
NRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
K0
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k2
*
KPack
*
NLane
*
KLane
*
NWave
+
n2
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n3
*
KPack
+
k3
;
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
}
}
}
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
...
@@ -71,6 +116,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -71,6 +116,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_preshuffled
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
// use layout only for size
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
...
@@ -125,22 +172,21 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -125,22 +172,21 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitKBPreShuffle
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDSplitK
<
ALayout
,
ALayout
,
BLayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ELayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
EDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
>
;
CElementOp
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
@@ -188,8 +234,20 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -188,8 +234,20 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
// profile device GEMM instances
// profile device GEMM instances
for
(
auto
&
op_ptr
:
op_ptrs
)
for
(
auto
&
op_ptr
:
op_ptrs
)
{
{
// TODO: Shuffle the weight
auto
preshuffle_params
=
op_ptr
->
GetPreShuffleParameters
();
// ...
preShuffleBuffer
(
b_k_n
.
mData
.
data
(),
b_preshuffled
.
mData
.
data
(),
N
,
K
,
preshuffle_params
[
0
],
preshuffle_params
[
1
],
preshuffle_params
[
2
],
preshuffle_params
[
3
],
preshuffle_params
[
4
],
preshuffle_params
[
5
]);
b_device_buf
.
ToDevice
(
b_preshuffled
.
mData
.
data
());
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
};
std
::
vector
<
int
>
kbatch_list
=
{
1
,
2
,
4
,
8
,
16
};
...
@@ -224,12 +282,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
...
@@ -224,12 +282,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
// re-init C to zero before profiling next kernel
c_device_buf
.
SetZero
();
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
,
0
,
n_warmup
,
n_iter
});
if
(
do_verification
)
if
(
do_verification
)
{
{
...
...
profiler/src/profile_gemm_multiply_multiply_weight_preshuffle.cpp
View file @
f60f9d59
...
@@ -74,10 +74,10 @@ int profile_gemm_multiply_multiply_weight_preshuffle(int argc, char* argv[])
...
@@ -74,10 +74,10 @@ int profile_gemm_multiply_multiply_weight_preshuffle(int argc, char* argv[])
using
F32
=
float
;
using
F32
=
float
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F16
=
ck
::
half_t
;
//
using F16 = ck::half_t;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
I8
=
int8_t
;
//
using I8 = int8_t;
using
I32
=
int
;
//
using I32 = int;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment