Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e70aa44a
Commit
e70aa44a
authored
Jul 19, 2022
by
Jing Zhang
Browse files
add batch_strides into bmm_c_permute
parent
ffd6943e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
231 additions
and
181 deletions
+231
-181
example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
...atched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
+53
-44
example/28_batched_gemm_multi_d/batched_gemm_xdl_fp16.cpp
example/28_batched_gemm_multi_d/batched_gemm_xdl_fp16.cpp
+9
-6
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
...or_operation/gpu/device/device_batched_gemm_c_permute.hpp
+30
-8
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
...peration/gpu/device/device_batched_gemm_c_permute_xdl.hpp
+139
-123
No files found.
example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
View file @
e70aa44a
...
...
@@ -26,35 +26,36 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static
constexpr
auto
MNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
//
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmCPermuteXdl
//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
<
Row
,
Col
,
F16
,
F16
,
F16
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
MNKPadding
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
C
DataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceBatchedGemm
<
ADataType
,
BDataType
,
E
DataType
,
AElementOp
,
BElementOp
,
C
DE
ElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -62,15 +63,18 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
bool
time_kernel
=
false
;
const
int
M
=
88
;
const
int
N
=
64
;
const
int
K
=
88
;
const
int
M
=
256
;
const
int
N
=
128
;
const
int
K
=
64
;
const
int
stride_A
=
K
;
const
int
stride_B
=
K
;
const
int
G0
=
1024
;
const
int
G1
=
10
;
const
int
batch_stride_A
=
M
*
K
;
const
int
batch_stride_B
=
K
*
N
;
const
int
G0
=
16
;
const
int
G1
=
8
;
const
int
batch_count
=
G0
*
G1
;
...
...
@@ -102,21 +106,24 @@ int main(int argc, char* argv[])
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count_
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
row
*
stride
,
stride
,
1
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count_
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
col
*
stride
,
1
,
stride
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
BLayout
{}));
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
BLayout
{}));
auto
f_host_c_tensor_descriptor
=
[](
std
::
size_t
G0_
,
std
::
size_t
G1_
,
...
...
@@ -131,10 +138,10 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
({
stride_G0_
,
stride_G1_
,
stride_M_
,
stride_N_
}));
};
Tensor
<
C
DataType
>
c_g0_g1_m_n_host_result
(
Tensor
<
E
DataType
>
c_g0_g1_m_n_host_result
(
f_host_c_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
Tensor
<
C
DataType
>
c_g0_g1_m_n_device_result
(
Tensor
<
E
DataType
>
c_g0_g1_m_n_device_result
(
f_host_c_tensor_descriptor
(
G0
,
G1
,
M
,
N
,
stride_G0
,
stride_G1
,
stride_M
,
stride_N
));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
...
...
@@ -156,32 +163,34 @@ int main(int argc, char* argv[])
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
C
DataType
)
*
c_g0_g1_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
E
DataType
)
*
c_g0_g1_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c
de
_element_op
=
C
DE
ElementOp
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
// do GEM
M
// do GEM
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
C
DataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
E
DataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
stride_A
,
stride_B
,
batch_stride_A
,
batch_stride_B
,
batched_gemm_c_permute_desc
,
batch_count
,
a_element_op
,
b_element_op
,
c_element_op
,
batch_count
);
cde_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
...
...
@@ -195,7 +204,7 @@ int main(int argc, char* argv[])
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
batch_count
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
batch_count
*
M
*
K
+
sizeof
(
BDataType
)
*
batch_count
*
K
*
N
+
sizeof
(
C
DataType
)
*
batch_count
*
M
*
N
;
sizeof
(
E
DataType
)
*
batch_count
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -213,11 +222,11 @@ int main(int argc, char* argv[])
auto
ref_batched_gemm
=
ReferenceBatchedGemmInstance
{};
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
Tensor
<
C
DataType
>
c_g_m_n_host_result
=
HostTensorDescriptor
(
Tensor
<
E
DataType
>
c_g_m_n_host_result
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count
,
M
,
N
}),
std
::
vector
<
std
::
size_t
>
({
M
*
N
,
N
,
1
}));
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c_element_op
);
a_g_m_k
,
b_g_k_n
,
c_g_m_n_host_result
,
a_element_op
,
b_element_op
,
c
de
_element_op
);
ref_invoker
.
Run
(
ref_argument
);
...
...
example/28_batched_gemm_multi_d/batched_gemm_xdl_fp16.cpp
View file @
e70aa44a
...
...
@@ -96,24 +96,27 @@ int main(int argc, char* argv[])
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
std
::
size_t
batch_stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count_
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
row
*
stride
,
stride
,
1
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
batch_count_
,
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
col
*
stride
,
1
,
stride
}));
std
::
vector
<
std
::
size_t
>
({
batch_
stride
,
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
BLayout
{}));
Tensor
<
ADataType
>
a_g_m_k
(
f_host_tensor_descriptor
(
batch_count
,
M
,
K
,
stride_A
,
batch_stride_A
,
ALayout
{}));
Tensor
<
BDataType
>
b_g_k_n
(
f_host_tensor_descriptor
(
batch_count
,
K
,
N
,
stride_B
,
batch_stride_B
,
BLayout
{}));
Tensor
<
EDataType
>
e_g_m_n_device_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
ELayout
{}));
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
ELayout
{}));
std
::
cout
<<
"a_g_m_k: "
<<
a_g_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_g_k_n: "
<<
b_g_k_n
.
mDesc
<<
std
::
endl
;
...
...
@@ -198,7 +201,7 @@ int main(int argc, char* argv[])
auto
ref_invoker
=
ref_batched_gemm
.
MakeInvoker
();
Tensor
<
EDataType
>
e_g_m_n_host_result
(
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
ELayout
{}));
f_host_tensor_descriptor
(
batch_count
,
M
,
N
,
stride_C
,
batch_stride_C
,
ELayout
{}));
auto
ref_argument
=
ref_batched_gemm
.
MakeArgument
(
a_g_m_k
,
b_g_k_n
,
e_g_m_n_host_result
,
a_element_op
,
b_element_op
,
cde_element_op
);
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
View file @
e70aa44a
...
...
@@ -14,9 +14,15 @@ struct BatchedGemmCPermuteDesc
ck
::
index_t
stride_G0_
,
stride_G1_
,
stride_M_
,
stride_N_
;
};
template
<
typename
AElementwiseOperation
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
C
DE
ElementwiseOperation
>
struct
DeviceBatchedGemmCPermute
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
...
...
@@ -28,20 +34,36 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
index_t
BatchCount
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
BatchCount
)
=
0
;
CDEElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceBatchedGemmCPermutePtr
=
std
::
unique_ptr
<
DeviceBatchedGemmCPermute
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
typename
CDEElementwiseOperation
>
using
DeviceBatchedGemmCPermutePtr
=
std
::
unique_ptr
<
DeviceBatchedGemmCPermute
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
View file @
e70aa44a
...
...
@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/device_utility/device_prop.hpp"
...
...
@@ -45,12 +46,12 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
AGridDesc_
A
K0_M_
A
K1
,
typename
BGridDesc_
B
K0_N_
B
K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
typename
ComputePtrOffsetOfBatch
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
...
...
@@ -60,15 +61,15 @@ __global__ void
#endif
kernel_batched_gemm_c_permute_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_
c
_grid
,
FloatC
*
__restrict__
p_
e
_grid
,
const
index_t
batch_count
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
AGridDesc_
A
K0_M_
A
K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_
B
K0_N_
B
K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
C
DE
ElementwiseOperation
c
de
_element_op
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
{
...
...
@@ -90,11 +91,11 @@ __global__ void
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
ck
::
Tuple
<>
{},
p_
c
_grid
+
c_batch_offset
,
p_
e
_grid
+
c_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
c
de
_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ck
::
StaticallyIndexedArray
<
...
...
@@ -105,14 +106,14 @@ __global__ void
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_
c
_grid
;
ignore
=
p_
e
_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c
de
_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_ctile_map
;
#endif
...
...
@@ -120,48 +121,60 @@ __global__ void
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
GemmAccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
NumPrefetch
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
index_t
Num
GemmK
Prefetch
Stage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_
A
K0_M_
A
K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLds
Add
ExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_
A
K1
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_
B
K0_N_
B
K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLds
Add
ExtraN
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_
B
K1
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceBatchedGemmCPermuteXdl
:
public
DeviceBatchedGemmCPermute
<
AElementwiseOperation
,
struct
DeviceBatchedGemmCPermuteXdl
:
public
DeviceBatchedGemmCPermute
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
C
DE
ElementwiseOperation
>
{
using
DeviceOp
=
DeviceBatchedGemmCPermuteXdl
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -373,7 +386,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
static
auto
Make
C
GridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
stride_M
,
index_t
stride_N
)
Make
E
GridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
stride_M
,
index_t
stride_N
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
...
...
@@ -489,9 +502,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
C
GridDesc_M_N
=
decltype
(
Make
C
GridDescriptor_M_N
(
1
,
1
,
1
,
1
));
using
AGridDesc_
A
K0_M_
A
K1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_
B
K0_N_
B
K1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
E
GridDesc_M_N
=
decltype
(
Make
E
GridDescriptor_M_N
(
1
,
1
,
1
,
1
));
using
EGridDesc_G0_G1_M_N
=
decltype
(
MakeEGridDescriptor_G0_G1_M_N
(
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
));
struct
ComputePtrOffsetOfStridedBatch
...
...
@@ -531,18 +544,18 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
using
GridwiseGemm
=
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
//
CShuffleDataType,
ck
::
Tuple
<>
,
//
DsDataType,
CDataType
,
//
EDataType,
Gemm
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C
DE
ElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
C
GridDesc_M_N
,
NumPrefetch
,
AGridDesc_
A
K0_M_
A
K1
,
BGridDesc_
B
K0_N_
B
K1
,
E
GridDesc_M_N
,
Num
GemmK
Prefetch
Stage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
...
...
@@ -553,22 +566,22 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_
A
K0_M_
A
K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLds
Add
ExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
ABlockTransferDstScalarPerVector_
A
K1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_
B
K0_N_
B
K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLds
Add
ExtraN
,
BBlockTransferDstScalarPerVector_
B
K1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -576,7 +589,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
LoopSched
>
;
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C
GridDesc_M_N
{}));
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
E
GridDesc_M_N
{}));
using
Block2CTileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
...
...
@@ -584,26 +597,28 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
C
DataType
*
p_
c
_grid
,
E
DataType
*
p_
e
_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
index_t
BatchCount
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
BatchCount
)
CDEElementwiseOperation
cde_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_
c
_grid_
{
p_
c
_grid
},
p_
e
_grid_
{
p_
e
_grid
},
BatchCount_
(
BatchCount
),
a_grid_desc_k0_m_k1_
{
a_grid_desc_
a
k0_m_
a
k1_
{
DeviceBatchedGemmCPermuteXdl
::
MakeAGridDescriptor_AK0_M_AK1
(
M
,
K
,
stride_A
)},
b_grid_desc_k0_n_k1_
{
b_grid_desc_
b
k0_n_
b
k1_
{
DeviceBatchedGemmCPermuteXdl
::
MakeBGridDescriptor_BK0_N_BK1
(
K
,
N
,
stride_B
)},
c
_grid_desc_m_n_
{
DeviceBatchedGemmCPermuteXdl
::
Make
C
GridDescriptor_M_N
(
e
_grid_desc_m_n_
{
DeviceBatchedGemmCPermuteXdl
::
Make
E
GridDescriptor_M_N
(
batched_gemm_c_permute_desc
.
M_
,
batched_gemm_c_permute_desc
.
N_
,
batched_gemm_c_permute_desc
.
stride_M_
,
...
...
@@ -618,42 +633,39 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
batched_gemm_c_permute_desc
.
stride_M_
,
batched_gemm_c_permute_desc
.
stride_N_
)},
c_grid_desc_mblock_mperblock_nblock_nperblock
{},
compute_ptr_offset_of_batch_
{
type_convert
<
index_t
>
(
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
()),
type_convert
<
index_t
>
(
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
()),
e_grid_desc_g0_g1_m_n_
},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
c_grid_desc_m_n_
)},
compute_ptr_offset_of_batch_
{
batch_stride_A
,
batch_stride_B
,
e_grid_desc_g0_g1_m_n_
},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c
de
_element_op_
{
c
de
_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c
_grid_desc_m_n_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
a
k0_m_
a
k1_
,
b_grid_desc_
b
k0_n_
b
k1_
,
e
_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c
_grid_desc_m_n_
);
e
_grid_desc_m_n_
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
C
DataType
*
p_
c
_grid_
;
E
DataType
*
p_
e
_grid_
;
index_t
BatchCount_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
C
GridDesc_M_N
c
_grid_desc_m_n_
;
AGridDesc_
A
K0_M_
A
K1
a_grid_desc_
a
k0_m_
a
k1_
;
BGridDesc_
B
K0_N_
B
K1
b_grid_desc_
b
k0_n_
b
k1_
;
E
GridDesc_M_N
e
_grid_desc_m_n_
;
EGridDesc_G0_G1_M_N
e_grid_desc_g0_g1_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
C
DE
ElementwiseOperation
c
de
_element_op_
;
};
// Invoker
...
...
@@ -664,21 +676,23 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_bk0_n_bk1_{"
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.
c
_grid_desc_m_n_{"
<<
arg
.
c
_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c
_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.
e
_grid_desc_m_n_{"
<<
arg
.
e
_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
e
_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c
_grid_desc_m_n_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
a
k0_m_
a
k1_
,
arg
.
b_grid_desc_
b
k0_n_
b
k1_
,
arg
.
e
_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
...
...
@@ -687,10 +701,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c
_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
e
_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
arg
.
a_grid_desc_
a
k0_m_
a
k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_
a
k0_m_
a
k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
...
...
@@ -698,13 +712,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
const
auto
kernel
=
kernel_batched_gemm_c_permute_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
C
DataType
,
remove_reference_t
<
DeviceBatchedGemmCPermuteXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceBatchedGemmCPermuteXdl
::
BGridDesc_K0_N_K1
>
,
E
DataType
,
AGridDesc_
A
K0_M_
A
K1
,
BGridDesc_
B
K0_N_
B
K1
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
C
DE
ElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
has_main_k_block_loop_
>
;
...
...
@@ -716,14 +730,14 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_
c
_grid_
,
arg
.
p_
e
_grid_
,
arg
.
BatchCount_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_
a
k0_m_
a
k1_
,
arg
.
b_grid_desc_
b
k0_n_
b
k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c
de
_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
};
...
...
@@ -756,31 +770,27 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c
_grid_desc_m_n_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
a
k0_m_
a
k1_
,
arg
.
b_grid_desc_
b
k0_n_
b
k1_
,
arg
.
e
_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
C
DataType
*
p_c
,
E
DataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
index_t
BatchCount
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
BatchCount
)
CDEElementwiseOperation
cde_element_op
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -790,11 +800,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
K
,
stride_A
,
stride_B
,
batch_stride_A
,
batch_stride_B
,
batched_gemm_c_permute_desc
,
BatchCount
,
a_element_op
,
b_element_op
,
c_element_op
,
BatchCount
};
cde_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -809,25 +821,29 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
index_t
K
,
index_t
stride_A
,
index_t
stride_B
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
index_t
BatchCount
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
BatchCount
)
override
CDEElementwiseOperation
cde_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
C
DataType
*>
(
p_c
),
static_cast
<
E
DataType
*>
(
p_c
),
M
,
N
,
K
,
stride_A
,
stride_B
,
batch_stride_A
,
batch_stride_B
,
batched_gemm_c_permute_desc
,
BatchCount
,
a_element_op
,
b_element_op
,
c_element_op
,
BatchCount
);
cde_element_op
);
}
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment