Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ed3c27cc
Commit
ed3c27cc
authored
Jul 26, 2022
by
Chao Liu
Browse files
update gemm and batch gemm with e permute
parent
dfbb659a
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
810 additions
and
337 deletions
+810
-337
example/24_batched_gemm_c_permute/CMakeLists.txt
example/24_batched_gemm_c_permute/CMakeLists.txt
+0
-2
example/24_batched_gemm_e_permute/CMakeLists.txt
example/24_batched_gemm_e_permute/CMakeLists.txt
+2
-0
example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
...atched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp
+41
-40
example/25_gemm_bias_c_permute/CMakeLists.txt
example/25_gemm_bias_c_permute/CMakeLists.txt
+0
-1
example/25_gemm_bias_e_permute/CMakeLists.txt
example/25_gemm_bias_e_permute/CMakeLists.txt
+1
-0
example/25_gemm_bias_e_permute/gemm_bias_e_permute_xdl_fp16.cpp
...e/25_gemm_bias_e_permute/gemm_bias_e_permute_xdl_fp16.cpp
+2
-2
example/CMakeLists.txt
example/CMakeLists.txt
+2
-2
include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp
...or_operation/gpu/device/device_batched_gemm_e_permute.hpp
+6
-12
include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp
...peration/gpu/device/device_batched_gemm_e_permute_xdl.hpp
+669
-0
include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp
...ensor_operation/gpu/device/device_gemm_bias_e_permute.hpp
+0
-6
include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp
...r_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp
+84
-269
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+3
-3
No files found.
example/24_batched_gemm_c_permute/CMakeLists.txt
deleted
100644 → 0
View file @
dfbb659a
add_example_executable
(
example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp
)
example/24_batched_gemm_e_permute/CMakeLists.txt
0 → 100644
View file @
ed3c27cc
add_example_executable
(
example_batched_gemm_e_permute_xdl_fp16 batched_gemm_e_permute_xdl_fp16.cpp
)
example/24_batched_gemm_
c
_permute/batched_gemm_
c
_permute_xdl_fp16.cpp
→
example/24_batched_gemm_
e
_permute/batched_gemm_
e
_permute_xdl_fp16.cpp
View file @
ed3c27cc
...
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_
c
_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_
e
_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
@@ -28,33 +28,33 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_t
;
using
EDataType
=
ck
::
half_t
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
C
DE
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemm
C
PermuteXdl
//######| ALayout| BLayout| AData| BData|
C
Data|
Acc
Data| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16,
F16
,
F32
, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
<
Row
,
Col
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemm
E
PermuteXdl
//######| ALayout| BLayout| AData| BData|
Acc
Data|
CShuffle| E
Data| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type|
Type|
Data|
Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | |
|
Type|
| Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | |
|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16,
F32
,
F16, F16
, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
<
Row
,
Col
,
F16
,
F16
,
F32
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
C
DataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
E
DataType
,
AElementOp
,
BElementOp
,
C
DE
ElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
}
// GEMM shape
ck
::
tensor_operation
::
device
::
BatchedGemm
C
PermuteDesc
batched_gemm_
c
_permute_desc
{
ck
::
tensor_operation
::
device
::
BatchedGemm
E
PermuteDesc
batched_gemm_
e
_permute_desc
{
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
};
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
batch_count_
,
...
...
@@ -118,7 +118,7 @@ int main(int argc, char* argv[])
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
BLayout
{}));
auto
f_host_
c
_tensor_descriptor
=
[](
std
::
size_t
G0_
,
auto
f_host_
e
_tensor_descriptor
=
[](
std
::
size_t
G0_
,
std
::
size_t
G1_
,
std
::
size_t
M_
,
std
::
size_t
N_
,
...
...
@@ -131,15 +131,15 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
({
stride_G0_
,
stride_G1_
,
stride_M_
,
stride_N_
}));
};
Tensor
<
C
DataType
>
c
_g0_g1_m_n_host_result
(
f_host_
c
_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
Tensor
<
E
DataType
>
e
_g0_g1_m_n_host_result
(
f_host_
e
_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
Tensor
<
C
DataType
>
c
_g0_g1_m_n_device_result
(
f_host_
c
_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
Tensor
<
E
DataType
>
e
_g0_g1_m_n_device_result
(
f_host_
e
_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c
_g0_g1_m_n: "
<<
c
_g0_g1_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
e
_g0_g1_m_n: "
<<
e
_g0_g1_m_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -156,15 +156,15 @@ int main(int argc, char* argv[])
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c
_device_buf
(
sizeof
(
C
DataType
)
*
c
_g0_g1_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e
_device_buf
(
sizeof
(
E
DataType
)
*
e
_g0_g1_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c
de
_element_op
=
C
DE
ElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
...
...
@@ -172,16 +172,16 @@ int main(int argc, char* argv[])
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
C
DataType
*>
(
c
_device_buf
.
GetDeviceBuffer
()),
static_cast
<
E
DataType
*>
(
e
_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
stride_A
,
stride_B
,
batched_gemm_
c
_permute_desc
,
batched_gemm_
e
_permute_desc
,
a_element_op
,
b_element_op
,
c_element_op
,
c
de
_element_op
,
batch_count
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
...
...
@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
batch_count
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
batch_count
*
M
*
K
+
sizeof
(
BDataType
)
*
batch_count
*
K
*
N
+
sizeof
(
C
DataType
)
*
batch_count
*
M
*
N
;
sizeof
(
E
DataType
)
*
batch_count
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -209,16 +209,16 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
c
_device_buf
.
FromDevice
(
c
_g0_g1_m_n_device_result
.
mData
.
data
());
e
_device_buf
.
FromDevice
(
e
_g0_g1_m_n_device_result
.
mData
.
data
());
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
Tensor
<
C
DataType
>
c_g_m_n_host_result
=
HostTensorDescriptor
(
Tensor
<
E
DataType
>
c_g_m_n_host_result
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
M
,
N
}),
std
::
vector
<
std
::
size_t
>
({
M
*
N
,
N
,
1
}));
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c
de
_element_op
);
ref_invoker
.
Run
(
ref_argument
);
...
...
@@ -231,14 +231,15 @@ int main(int argc, char* argv[])
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
int
g
=
g0
*
G1
+
g1
;
c_g0_g1_m_n_host_result
(
g0
,
g1
,
m
,
n
)
=
c_g_m_n_host_result
(
g
,
m
,
n
);
e_g0_g1_m_n_host_result
(
g0
,
g1
,
m
,
n
)
=
c_g_m_n_host_result
(
g
,
m
,
n
);
}
}
}
}
pass
=
ck
::
utils
::
check_err
(
c
_g0_g1_m_n_host_result
.
mData
,
c
_g0_g1_m_n_device_result
.
mData
,
pass
=
ck
::
utils
::
check_err
(
e
_g0_g1_m_n_host_result
.
mData
,
e
_g0_g1_m_n_device_result
.
mData
,
"Error: Incorrect results c"
);
}
...
...
example/25_gemm_bias_c_permute/CMakeLists.txt
deleted
100644 → 0
View file @
dfbb659a
add_example_executable
(
example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp
)
example/25_gemm_bias_e_permute/CMakeLists.txt
0 → 100644
View file @
ed3c27cc
add_example_executable
(
example_gemm_bias_e_permute_xdl_fp16 gemm_bias_e_permute_xdl_fp16.cpp
)
example/25_gemm_bias_
c
_permute/gemm_bias_
c
_permute_xdl_fp16.cpp
→
example/25_gemm_bias_
e
_permute/gemm_bias_
e
_permute_xdl_fp16.cpp
View file @
ed3c27cc
...
...
@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_
c
_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_
e
_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
...
...
@@ -49,7 +49,7 @@ using CDEElementOp = Add;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmBias
C
Permute_Xdl
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmBias
E
Permute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
...
example/CMakeLists.txt
View file @
ed3c27cc
...
...
@@ -38,8 +38,8 @@ add_subdirectory(20_convnd_bwd_weight)
add_subdirectory
(
21_gemm_layernorm
)
add_subdirectory
(
22_cgemm
)
add_subdirectory
(
23_softmax
)
add_subdirectory
(
24_batched_gemm_
c
_permute
)
add_subdirectory
(
25_gemm_bias_
c
_permute
)
add_subdirectory
(
24_batched_gemm_
e
_permute
)
add_subdirectory
(
25_gemm_bias_
e
_permute
)
add_subdirectory
(
26_contraction
)
add_subdirectory
(
27_layernorm
)
add_subdirectory
(
28_group_convnd_fwd_bias_relu
)
include/ck/tensor_operation/gpu/device/device_batched_gemm_
c
_permute.hpp
→
include/ck/tensor_operation/gpu/device/device_batched_gemm_
e
_permute.hpp
View file @
ed3c27cc
...
...
@@ -8,7 +8,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
struct
BatchedGemm
C
PermuteDesc
struct
BatchedGemm
E
PermuteDesc
{
ck
::
index_t
G0_
,
G1_
,
M_
,
N_
;
ck
::
index_t
stride_G0_
,
stride_G1_
,
stride_M_
,
stride_N_
;
...
...
@@ -16,33 +16,27 @@ struct BatchedGemmCPermuteDesc
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceBatchedGemm
C
Permute
:
public
BaseOperator
typename
C
DE
ElementwiseOperation
>
struct
DeviceBatchedGemm
E
Permute
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_
c
,
void
*
p_
e
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
BatchedGemm
C
PermuteDesc
batched_gemm_
c
_permute_desc
,
BatchedGemm
E
PermuteDesc
batched_gemm_
e
_permute_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C
DE
ElementwiseOperation
c
de
_element_op
,
ck
::
index_t
BatchCount
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceBatchedGemmCPermutePtr
=
std
::
unique_ptr
<
DeviceBatchedGemmCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_batched_gemm_
c
_permute_xdl.hpp
→
include/ck/tensor_operation/gpu/device/device_batched_gemm_
e
_permute_xdl.hpp
View file @
ed3c27cc
...
...
@@ -7,8 +7,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_
c
_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_
e
_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -24,9 +25,10 @@ namespace device {
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
* limitations.
*
* \tparam Block2
C
TileMap Block2
C
TileMap::CalculateBottomIndex() takes in id of a workgroup and
* \tparam Block2
E
TileMap Block2
E
TileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
...
...
@@ -37,40 +39,40 @@ namespace device {
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2
C
TileMap allows customized mapping between a workgroup and the C-tile it computes.
* \note \p Block2
E
TileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
C
GridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ABDataType
,
typename
EDataType
,
typename
AGridDesc_
A
K0_M_
A
K1
,
typename
BGridDesc_
B
K0_N_
B
K1
,
typename
E
GridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
typename
ComputePtrOffsetOfBatch
,
typename
Block2
C
TileMap
,
typename
Block2
E
TileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_
c
_permute_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_
c
_grid
,
kernel_batched_gemm_
e
_permute_xdl
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
EDataType
*
__restrict__
p_
e
_grid
,
const
index_t
batch_count
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
C
GridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AGridDesc_
A
K0_M_
A
K1
a_grid_desc_
a
k0_m_
a
k1
,
const
BGridDesc_
B
K0_N_
B
K1
b_grid_desc_
b
k0_n_
b
k1
,
const
E
GridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C
DE
ElementwiseOperation
c
de
_element_op
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2
C
TileMap
block_2_
c
tile_map
)
const
Block2
E
TileMap
block_2_
e
tile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const
index_t
num_blocks_per_batch
=
...
...
@@ -81,40 +83,37 @@ __global__ void
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
c
_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
e
_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
ck
::
Tuple
<>
{},
p_
c
_grid
+
c
_batch_offset
,
p_
e
_grid
+
e
_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
0
>
{},
c_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_ctile_map
);
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
ck
::
Tuple
<>
{},
e_grid_desc_mblock_mperblock_nblock_nperblock
,
block_2_etile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_
c
_grid
;
ignore
=
p_
e
_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c
_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_grid_desc_
a
k0_m_
a
k1
;
ignore
=
b_grid_desc_
b
k0_n_
b
k1
;
ignore
=
e
_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c
de
_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_
c
tile_map
;
ignore
=
block_2_
e
tile_map
;
#endif
}
...
...
@@ -122,51 +121,57 @@ template <typename ALayout,
typename
BLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
index_t
NumPrefetch
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLds
Add
ExtraM
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLds
Add
ExtraN
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceBatchedGemm
C
PermuteXdl
:
public
DeviceBatchedGemm
C
Permute
<
AElementwiseOperation
,
struct
DeviceBatchedGemm
E
PermuteXdl
:
public
DeviceBatchedGemm
E
Permute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
C
DE
ElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedGemmEPermuteXdl
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
...
@@ -181,95 +186,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
static
auto
MakeBGridDescriptor_
BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_
N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
...
@@ -284,142 +204,16 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
static
auto
Make
C
GridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
stride_M
,
index_t
stride_N
)
Make
E
GridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
stride_M
,
index_t
stride_N
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
stride_M
,
stride_N
));
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
e_grid_desc_mraw_nraw
=
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
stride_M
,
stride_N
));
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
static
auto
MakeEGridDescriptor_G0_G1_M_N
(
index_t
G0
,
...
...
@@ -489,9 +283,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}
using
AGridDesc_
K0_M_K1
=
decltype
(
MakeAGridDescriptor_
AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_
K0_N_K1
=
decltype
(
MakeBGridDescriptor_
BK0_N_BK1
(
1
,
1
,
1
));
using
C
GridDesc_M_N
=
decltype
(
Make
C
GridDescriptor_M_N
(
1
,
1
,
1
,
1
));
using
AGridDesc_
M_K
=
decltype
(
MakeAGridDescriptor_
M_K
(
1
,
1
,
1
));
using
BGridDesc_
N_K
=
decltype
(
MakeBGridDescriptor_
N_K
(
1
,
1
,
1
));
using
E
GridDesc_M_N
=
decltype
(
Make
E
GridDescriptor_M_N
(
1
,
1
,
1
,
1
));
using
EGridDesc_G0_G1_M_N
=
decltype
(
MakeEGridDescriptor_G0_G1_M_N
(
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
));
struct
ComputePtrOffsetOfStridedBatch
...
...
@@ -529,19 +323,20 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
EGridDesc_G0_G1_M_N
e_grid_desc_g0_g1_m_n_
;
};
using
GridwiseGemm
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
//
CShuffleDataType,
CShuffleDataType
,
ck
::
Tuple
<>
,
// DsDataType,
C
DataType
,
// EDataType,
E
DataType
,
// EDataType,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C
DE
ElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AGridDesc_M_K
,
BGridDesc_N_K
,
Tuple
<>
,
EGridDesc_M_N
,
NumPrefetch
,
BlockSize
,
MPerBlock
,
...
...
@@ -560,7 +355,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLds
Add
ExtraM
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
...
...
@@ -568,118 +363,135 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLds
Add
ExtraN
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}));
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
C
DataType
*
p_
c
_grid
,
E
DataType
*
p_
e
_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
BatchedGemm
C
PermuteDesc
batched_gemm_
c
_permute_desc
,
BatchedGemm
E
PermuteDesc
batched_gemm_
e
_permute_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C
DE
ElementwiseOperation
c
de
_element_op
,
index_t
BatchCount
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_
c
_grid_
{
p_
c
_grid
},
p_
e
_grid_
{
p_
e
_grid
},
BatchCount_
(
BatchCount
),
a_grid_desc_k0_m_k1_
{
DeviceBatchedGemmCPermuteXdl
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
K
,
stride_A
)},
b_grid_desc_k0_n_k1_
{
DeviceBatchedGemmCPermuteXdl
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
N
,
stride_B
)},
c_grid_desc_m_n_
{
DeviceBatchedGemmCPermuteXdl
::
MakeCGridDescriptor_M_N
(
batched_gemm_c_permute_desc
.
M_
,
batched_gemm_c_permute_desc
.
N_
,
batched_gemm_c_permute_desc
.
stride_M_
,
batched_gemm_c_permute_desc
.
stride_N_
)},
e_grid_desc_g0_g1_m_n_
{
DeviceBatchedGemmCPermuteXdl
::
MakeEGridDescriptor_G0_G1_M_N
(
batched_gemm_c_permute_desc
.
G0_
,
batched_gemm_c_permute_desc
.
G1_
,
batched_gemm_c_permute_desc
.
M_
,
batched_gemm_c_permute_desc
.
N_
,
batched_gemm_c_permute_desc
.
stride_G0_
,
batched_gemm_c_permute_desc
.
stride_G1_
,
batched_gemm_c_permute_desc
.
stride_M_
,
batched_gemm_c_permute_desc
.
stride_N_
)},
c_grid_desc_mblock_mperblock_nblock_nperblock
{},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
M
,
K
,
stride_A
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
K
,
N
,
stride_B
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
batched_gemm_e_permute_desc
.
M_
,
batched_gemm_e_permute_desc
.
N_
,
batched_gemm_e_permute_desc
.
stride_M_
,
batched_gemm_e_permute_desc
.
stride_N_
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
e_grid_desc_mblock_mperblock_nblock_nperblock
{},
e_grid_desc_g0_g1_m_n_
{
DeviceOp
::
MakeEGridDescriptor_G0_G1_M_N
(
batched_gemm_e_permute_desc
.
G0_
,
batched_gemm_e_permute_desc
.
G1_
,
batched_gemm_e_permute_desc
.
M_
,
batched_gemm_e_permute_desc
.
N_
,
batched_gemm_e_permute_desc
.
stride_G0_
,
batched_gemm_e_permute_desc
.
stride_G1_
,
batched_gemm_e_permute_desc
.
stride_M_
,
batched_gemm_e_permute_desc
.
stride_N_
)},
compute_ptr_offset_of_batch_
{
type_convert
<
index_t
>
(
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
a_grid_desc_
a
k0_m_
a
k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
b_grid_desc_
b
k0_n_
b
k1_
.
GetElementSpaceSize
()),
e_grid_desc_g0_g1_m_n_
},
block_2_
c
tile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
c
_grid_desc_m_n_
)},
block_2_
e
tile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e
_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c
de
_element_op_
{
c
de
_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a
_grid_desc_
k0_m
_k
1
_
,
b_grid_desc_k0_n_k1_
,
c
_grid_desc_m_n_
,
block_2_
c
tile_map_
))
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b
_grid_desc_
n
_k_
,
ck
::
Tuple
<>
{}
,
e
_grid_desc_m_n_
,
block_2_
e
tile_map_
))
{
c
_grid_desc_mblock_mperblock_nblock_nperblock
=
e
_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c
_grid_desc_m_n_
);
e
_grid_desc_m_n_
);
}
}
void
Print
()
const
{
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
std
::
cout
<<
"C[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
}
// private:
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
EDataType
*
p_e_grid_
;
// batch count
index_t
BatchCount_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
EGridDesc_G0_G1_M_N
e_grid_desc_g0_g1_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
;
// for calculating Batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
C
DE
ElementwiseOperation
c
de
_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
Device
BatchedGemmCPermuteXdl
::
Argument
;
using
Argument
=
Device
Op
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
ck
::
Tuple
<>
{},
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
...
...
@@ -687,26 +499,24 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
const
index_t
grid_size
=
arg
.
block_2_
c
tile_map_
.
CalculateGridSize
(
arg
.
c
_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
arg
.
block_2_
e
tile_map_
.
CalculateGridSize
(
arg
.
e
_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_batched_gemm_
c
_permute_xdl
<
const
auto
kernel
=
kernel_batched_gemm_
e
_permute_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
C
DataType
,
remove_reference_t
<
Device
BatchedGemmCPermuteXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
Device
BatchedGemmCPermuteXdl
::
BGridDesc_K0_N_K1
>
,
E
DataType
,
remove_reference_t
<
Device
Op
::
AGridDesc_
A
K0_M_
A
K1
>
,
remove_reference_t
<
Device
Op
::
BGridDesc_
B
K0_N_
B
K1
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C
DE
ElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2
C
TileMap
>
,
remove_reference_t
<
Block2
E
TileMap
>
,
has_main_k_block_loop_
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -716,28 +526,26 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_
c
_grid_
,
arg
.
p_
e
_grid_
,
arg
.
BatchCount_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_grid_desc_
a
k0_m_
a
k1_
,
arg
.
b_grid_desc_
b
k0_n_
b
k1_
,
arg
.
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c
de
_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_
c
tile_map_
);
arg
.
block_2_
e
tile_map_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
...
...
@@ -756,10 +564,11 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
ck
::
Tuple
<>
{},
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
// polymorphic
...
...
@@ -770,30 +579,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
C
DataType
*
p_
c
,
E
DataType
*
p_
e
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
BatchedGemm
C
PermuteDesc
batched_gemm_
c
_permute_desc
,
BatchedGemm
E
PermuteDesc
batched_gemm_
e
_permute_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C
DE
ElementwiseOperation
c
de
_element_op
,
index_t
BatchCount
)
{
return
Argument
{
p_a
,
p_b
,
p_
c
,
p_
e
,
M
,
N
,
K
,
stride_A
,
stride_B
,
batched_gemm_
c
_permute_desc
,
batched_gemm_
e
_permute_desc
,
a_element_op
,
b_element_op
,
c_element_op
,
c
de
_element_op
,
BatchCount
};
}
...
...
@@ -803,30 +612,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_
c
,
void
*
p_
e
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
BatchedGemm
C
PermuteDesc
batched_gemm_
c
_permute_desc
,
BatchedGemm
E
PermuteDesc
batched_gemm_
e
_permute_desc
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
C
DE
ElementwiseOperation
c
de
_element_op
,
index_t
BatchCount
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
C
DataType
*>
(
p_
c
),
static_cast
<
E
DataType
*>
(
p_
e
),
M
,
N
,
K
,
stride_A
,
stride_B
,
batched_gemm_
c
_permute_desc
,
batched_gemm_
e
_permute_desc
,
a_element_op
,
b_element_op
,
c_element_op
,
c
de
_element_op
,
BatchCount
);
}
...
...
@@ -842,7 +651,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchedGemm
C
PermuteXdl"
str
<<
"DeviceBatchedGemm
E
PermuteXdl"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_gemm_bias_
c
_permute.hpp
→
include/ck/tensor_operation/gpu/device/device_gemm_bias_
e
_permute.hpp
View file @
ed3c27cc
...
...
@@ -46,12 +46,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmBiasCPermutePtr
=
std
::
unique_ptr
<
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_bias_
c
_permute_xdl.hpp
→
include/ck/tensor_operation/gpu/device/device_gemm_bias_
e
_permute_xdl.hpp
View file @
ed3c27cc
...
...
@@ -10,8 +10,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_
c
_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_
e
_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -35,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_bias_
c
_permute
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_bias_
e
_permute
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatDsPointer
p_ds_grid
,
FloatE
*
__restrict__
p_e_grid
,
...
...
@@ -99,7 +100,7 @@ template <typename ALayout,
typename
CDELayout
,
typename
ADataType
,
typename
BDataType
,
typename
Gemm
AccDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DDataType
,
typename
EDataType
,
...
...
@@ -124,33 +125,36 @@ template <typename ALayout,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
ABlockLdsExtraM
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BBlockLdsExtraN
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmBias
C
Permute_Xdl
:
public
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
struct
DeviceGemmBias
E
Permute_Xdl
:
public
DeviceGemmBiasCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmBias
C
Permute_Xdl
;
using
DeviceOp
=
DeviceGemmBias
E
Permute_Xdl
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
NumDTensor
=
I1
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
constexpr
index_t
NumDTensor
=
1
;
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
...
@@ -165,95 +169,10 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
}
static
auto
MakeBGridDescriptor_
BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_
N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
...
@@ -268,92 +187,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
static
auto
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
d_e_grid_desc
)
...
...
@@ -370,73 +204,32 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
index_t
stride_N0
=
d_e_grid_desc
.
stride_N0_
;
index_t
stride_N1
=
d_e_grid_desc
.
stride_N1_
;
const
auto
MRaw
=
M0
*
M1
*
M2
;
const
auto
NRaw
=
N0
*
N1
;
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
c_grid_desc_m0_m1_m2_n0_n1
=
make_naive_tensor_descriptor
(
const
auto
e_grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
e_grid_desc_m0_m1_m2_n0_n1
=
make_naive_tensor_descriptor
(
make_tuple
(
M0
,
M1
,
M2
,
N0
,
N1
),
make_tuple
(
stride_M0
,
stride_M1
,
stride_M2
,
stride_N0
,
stride_N1
));
return
transform_tensor_descriptor
(
c
_grid_desc_m0_m1_m2_n0_n1
,
e
_grid_desc_m0_m1_m2_n0_n1
,
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_merge_transform
(
make_tuple
(
N0
,
N1
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
using
AGridDesc_
AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_
AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_
BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_
BK0_N_BK1
(
1
,
1
,
1
));
using
AGridDesc_
M_K
=
decltype
(
MakeAGridDescriptor_
M_K
(
1
,
1
,
1
));
using
BGridDesc_
N_K
=
decltype
(
MakeBGridDescriptor_
N_K
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
DEGridDesc_M0_M1_M2_N0_N1
{}));
using
DsGridDesc_M_N
=
Tuple
<
EGridDesc_M_N
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_
k0mk1_k0nk1_mn_
xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
Gemm
AccDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<
DDataType
>
,
EDataType
,
...
...
@@ -444,8 +237,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
AGridDesc_M_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
...
...
@@ -480,6 +274,13 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -499,12 +300,17 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
// FIXME
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
a_grid_desc_
ak0_m_ak1
_
{
DeviceOp
::
MakeAGridDescriptor_
AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_
bk0_n_bk1
_
{
DeviceOp
::
MakeBGridDescriptor_
BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m
block_mperblock_nblock_nperblock
_
{},
a_grid_desc_
m_k
_
{
DeviceOp
::
MakeAGridDescriptor_
M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_
n_k
_
{
DeviceOp
::
MakeBGridDescriptor_
N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m
_n
_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
e_grid_desc
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
...
...
@@ -522,8 +328,16 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
// populate pointer, desc for Ds
// D pointer
p_ds_grid_
(
I0
)
=
static_cast
<
const
DDataType
*>
(
p_d_grid
);
// D desc
ds_grid_desc_m_n_
(
I0
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
d_grid_desc
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
...
...
@@ -531,32 +345,37 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
p_ds_grid_
(
I0
)
=
static_cast
<
const
DDataType
*>
(
p_d_grid
);
const
auto
d_grid_desc_m_n
=
DeviceOp
::
MakeEGridDescriptor_M_N
(
d_grid_desc
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
I0
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
d
s
_grid_desc_m_n
_
[
I0
]
);
}
}
// private:
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2ETileMap
block_2_etile_map_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
...
...
@@ -569,8 +388,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
...
...
@@ -586,7 +406,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_bias_
c
_permute
<
const
auto
kernel
=
kernel_gemm_bias_
e
_permute
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
...
...
@@ -596,9 +416,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
...
...
@@ -622,18 +440,14 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
arg
.
block_2_etile_map_
);
};
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
...
...
@@ -651,8 +465,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
...
...
@@ -741,7 +556,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmBias
C
Permute_Xdl"
str
<<
"DeviceGemmBias
E
Permute_Xdl"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
ed3c27cc
...
...
@@ -205,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
{
const
auto
e_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELay
out
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELay
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideE
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELay
out
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELay
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
StrideE
));
...
...
@@ -329,7 +329,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
{
// populate pointer,
batch stride,
desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment