Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a59cb900
Commit
a59cb900
authored
Apr 24, 2023
by
ltqin
Browse files
regular code
parent
1cda3b80
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
111 additions
and
146 deletions
+111
-146
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
...nsor_operation_instance/add_device_operation_instance.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
...oftmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
...mm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
+49
-67
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
...mm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
+50
-67
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
No files found.
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
View file @
a59cb900
...
...
@@ -35,10 +35,10 @@ enum struct ArchitectureEnum
Dl
};
template
<
typename
DeviceOp
,
ArchitectureEnum
Arch
=
ArchitectureEnum
::
Xdl
>
struct
DeviceOperationInstance
Creator
;
struct
DeviceOperationInstance
s
;
template
<
typename
DeviceOp
,
ArchitectureEnum
Arch
>
struct
DeviceOperationInstance
Builde
r
;
struct
DeviceOperationInstance
Creato
r
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
View file @
a59cb900
...
...
@@ -36,7 +36,7 @@ template <index_t NumDimG,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
,
ArchitectureEnum
Arch
>
struct
DeviceOperationInstance
Builde
r
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
struct
DeviceOperationInstance
Creato
r
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
...
...
@@ -75,7 +75,7 @@ struct DeviceOperationInstanceBuilder<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
static
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>&
instances
)
{
add_device_operation_instances
(
instances
,
DeviceOperationInstance
Creator
<
DeviceOp
,
Arch
>::
create
_device_instances
());
instances
,
DeviceOperationInstance
s
<
DeviceOp
,
Arch
>::
get
_device_instances
());
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
View file @
a59cb900
...
...
@@ -13,32 +13,57 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
namespace
bf16_data
{
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
DataType
,
typename
AccDataType
,
typename
D0DataTypes
,
typename
AD0ElementwiseOp
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
using
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
struct
DeviceOperationInstances
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
,
ArchitectureEnum
::
Xdl
>
{
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
D0DataTypes
=
Acc0BiasDataType
;
using
AD0ElementwiseOp
=
C0DEElementwiseOperation
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| ype| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
...
...
@@ -60,55 +85,12 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
D0DataTypes
,
ck
::
Tuple
<>
,
AccDataType
,
DataType
,
PassThrough
,
PassThrough
,
AD0ElementwiseOp
,
PassThrough
,
PassThrough
,
GemmPadded
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
MaskingSpec
>
,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
D0DataTypes
,
ck
::
Tuple
<>
,
AccDataType
,
DataType
,
PassThrough
,
PassThrough
,
AD0ElementwiseOp
,
PassThrough
,
PassThrough
,
GemmPadded
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
MaskingSpec
>
// clang-format on
>
;
}
// namespace bf16_data
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceCreator
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
,
ArchitectureEnum
::
Xdl
>
{
static
auto
create_device_instances
()
// clang-format on
>
;
static
auto
get_device_instances
()
{
return
bf16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
bhalf_t
,
float
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
return
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
{};
}
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
View file @
a59cb900
...
...
@@ -13,32 +13,58 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
namespace
fp16_data
{
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
DataType
,
typename
AccDataType
,
typename
D0DataTypes
,
typename
AD0ElementwiseOp
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
using
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
struct
DeviceOperationInstances
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
,
ArchitectureEnum
::
Xdl
>
{
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
D0DataTypes
=
Acc0BiasDataType
;
using
AD0ElementwiseOp
=
C0DEElementwiseOperation
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
TensorDefault
=
ck
::
tensor_operation
::
device
::
TensorSpecialization
::
Default
;
using
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
=
std
::
tuple
<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| ype| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
...
...
@@ -60,55 +86,12 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
D0DataTypes
,
ck
::
Tuple
<>
,
AccDataType
,
DataType
,
PassThrough
,
PassThrough
,
AD0ElementwiseOp
,
PassThrough
,
PassThrough
,
GemmPadded
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
1
,
256
,
128
,
128
,
64
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
4
,
4
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
MaskingSpec
>
,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
DataType
,
DataType
,
DataType
,
DataType
,
D0DataTypes
,
ck
::
Tuple
<>
,
AccDataType
,
DataType
,
PassThrough
,
PassThrough
,
AD0ElementwiseOp
,
PassThrough
,
PassThrough
,
GemmPadded
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
TensorDefault
,
1
,
256
,
128
,
64
,
32
,
128
,
32
,
8
,
8
,
2
,
32
,
32
,
1
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
MaskingSpec
>
// clang-format on
>
;
}
// namespace fp16_data
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceCreator
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
,
ArchitectureEnum
::
Xdl
>
{
static
auto
create_device_instances
()
// clang-format on
>
;
static
auto
get_device_instances
()
{
return
fp16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
half_t
,
float
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
return
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
{};
}
};
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
a59cb900
...
...
@@ -66,7 +66,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
@@ -108,7 +108,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
a59cb900
...
...
@@ -65,7 +65,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
@@ -107,7 +107,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
a59cb900
...
...
@@ -65,7 +65,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
@@ -107,7 +107,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
a59cb900
...
...
@@ -65,7 +65,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
...
...
@@ -106,7 +106,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstance
Builde
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
DeviceOperationInstance
Creato
r
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment